Merge branch 'main' into feature/antigravity-user-agent-configurable

This commit is contained in:
Wesley Liddick
2026-02-24 14:01:43 +08:00
committed by GitHub
463 changed files with 64603 additions and 3674 deletions

View File

@@ -204,9 +204,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
// ExchangeCode 用 authorization code 交换 token
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("client_secret", clientSecret)
params.Set("code", code)
params.Set("redirect_uri", RedirectURI)
params.Set("grant_type", "authorization_code")
@@ -243,9 +248,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
// RefreshToken 刷新 access_token
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
clientSecret, err := getClientSecret()
if err != nil {
return nil, err
}
params := url.Values{}
params.Set("client_id", ClientID)
params.Set("client_secret", ClientSecret)
params.Set("client_secret", clientSecret)
params.Set("refresh_token", refreshToken)
params.Set("grant_type", "refresh_token")

File diff suppressed because it is too large Load Diff

View File

@@ -6,11 +6,14 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
@@ -21,7 +24,11 @@ const (
// Antigravity OAuth 客户端凭证
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
ClientSecret = ""
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
// 出于安全原因,该值不得硬编码入库。
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
// 固定的 redirect_uri用户需手动复制 code
RedirectURI = "http://localhost:8085/callback"
@@ -57,6 +64,17 @@ func init() {
// GetUserAgent 返回当前配置的 User-Agent
func GetUserAgent() string {
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
func getClientSecret() (string, error) {
if v := strings.TrimSpace(ClientSecret); v != "" {
return v, nil
}
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
if vv := strings.TrimSpace(v); vv != "" {
return vv, nil
}
}
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
}
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)

View File

@@ -0,0 +1,704 @@
//go:build unit
package antigravity
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"net/url"
"strings"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// getClientSecret
// ---------------------------------------------------------------------------
func TestGetClientSecret_环境变量设置(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
}
if secret != "my-secret-value" {
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
}
}
func TestGetClientSecret_环境变量为空(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, "")
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量为空时应返回错误")
}
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
}
}
func TestGetClientSecret_环境变量未设置(t *testing.T) {
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
// 明确设置再取消,确保环境变量不存在
t.Setenv(AntigravityOAuthClientSecretEnv, "")
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量未设置时应返回错误")
}
}
func TestGetClientSecret_环境变量含空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
_, err := getClientSecret()
if err == nil {
t.Fatal("环境变量仅含空格时应返回错误")
}
}
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
secret, err := getClientSecret()
if err != nil {
t.Fatalf("获取 client_secret 失败: %v", err)
}
if secret != "valid-secret" {
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
}
}
// ---------------------------------------------------------------------------
// ForwardBaseURLs
// ---------------------------------------------------------------------------
func TestForwardBaseURLs_Daily优先(t *testing.T) {
urls := ForwardBaseURLs()
if len(urls) == 0 {
t.Fatal("ForwardBaseURLs 返回空列表")
}
// daily URL 应排在第一位
if urls[0] != antigravityDailyBaseURL {
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
}
// 应包含所有 URL
if len(urls) != len(BaseURLs) {
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
}
// 验证 prod URL 也在列表中
found := false
for _, u := range urls {
if u == antigravityProdBaseURL {
found = true
break
}
}
if !found {
t.Error("ForwardBaseURLs 中缺少 prod URL")
}
}
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
originalFirst := BaseURLs[0]
_ = ForwardBaseURLs()
// 确保原始 BaseURLs 未被修改
if BaseURLs[0] != originalFirst {
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
}
}
// ---------------------------------------------------------------------------
// URLAvailability
// ---------------------------------------------------------------------------
func TestNewURLAvailability(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
if ua == nil {
t.Fatal("NewURLAvailability 返回 nil")
}
if ua.ttl != 5*time.Minute {
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
}
if ua.unavailable == nil {
t.Error("unavailable map 不应为 nil")
}
}
func TestURLAvailability_MarkUnavailable(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
testURL := "https://example.com"
ua.MarkUnavailable(testURL)
if ua.IsAvailable(testURL) {
t.Error("标记为不可用后 IsAvailable 应返回 false")
}
}
func TestURLAvailability_MarkSuccess(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
testURL := "https://example.com"
// 先标记为不可用
ua.MarkUnavailable(testURL)
if ua.IsAvailable(testURL) {
t.Error("标记为不可用后应不可用")
}
// 标记成功后应恢复可用
ua.MarkSuccess(testURL)
if !ua.IsAvailable(testURL) {
t.Error("MarkSuccess 后应恢复可用")
}
// 验证 lastSuccess 被设置
ua.mu.RLock()
if ua.lastSuccess != testURL {
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
}
ua.mu.RUnlock()
}
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
// 使用极短的 TTL
ua := NewURLAvailability(1 * time.Millisecond)
testURL := "https://example.com"
ua.MarkUnavailable(testURL)
// 等待 TTL 过期
time.Sleep(5 * time.Millisecond)
if !ua.IsAvailable(testURL) {
t.Error("TTL 过期后 URL 应恢复可用")
}
}
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
ua := NewURLAvailability(5 * time.Minute)
if !ua.IsAvailable("https://never-marked.com") {
t.Error("未标记的 URL 应默认可用")
}
}
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
// 默认所有 URL 都可用
urls := ua.GetAvailableURLs()
if len(urls) != len(BaseURLs) {
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
}
}
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
if len(BaseURLs) < 2 {
t.Skip("BaseURLs 少于 2 个,跳过此测试")
}
ua.MarkUnavailable(BaseURLs[0])
urls := ua.GetAvailableURLs()
// 标记的 URL 不应出现在可用列表中
for _, u := range urls {
if u == BaseURLs[0] {
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
}
}
}
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
urls := ua.GetAvailableURLsWithBase(customURLs)
if len(urls) != 3 {
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
ua.MarkSuccess("https://c.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
if len(urls) != 3 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
}
// c.com 应排在第一位
if urls[0] != "https://c.com" {
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
}
// 其余按原始顺序
if urls[1] != "https://a.com" {
t.Errorf("第二个应为 a.com: got %s", urls[1])
}
if urls[2] != "https://b.com" {
t.Errorf("第三个应为 b.com: got %s", urls[2])
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com"}
ua.MarkSuccess("https://b.com")
ua.MarkUnavailable("https://b.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
// b.com 被标记不可用,不应出现
if len(urls) != 1 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
}
if urls[0] != "https://a.com" {
t.Errorf("仅 a.com 应可用: got %s", urls[0])
}
}
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
ua := NewURLAvailability(10 * time.Minute)
customURLs := []string{"https://a.com", "https://b.com"}
ua.MarkSuccess("https://not-in-list.com")
urls := ua.GetAvailableURLsWithBase(customURLs)
// lastSuccess 不在自定义列表中,不应被添加
if len(urls) != 2 {
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
}
}
// ---------------------------------------------------------------------------
// SessionStore
// ---------------------------------------------------------------------------
func TestNewSessionStore(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
if store == nil {
t.Fatal("NewSessionStore 返回 nil")
}
if store.sessions == nil {
t.Error("sessions map 不应为 nil")
}
}
func TestSessionStore_SetAndGet(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "test-state",
CodeVerifier: "test-verifier",
ProxyURL: "http://proxy.example.com",
CreatedAt: time.Now(),
}
store.Set("session-1", session)
got, ok := store.Get("session-1")
if !ok {
t.Fatal("Get 应返回 true")
}
if got.State != "test-state" {
t.Errorf("State 不匹配: got %s", got.State)
}
if got.CodeVerifier != "test-verifier" {
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
}
if got.ProxyURL != "http://proxy.example.com" {
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
}
}
func TestSessionStore_Get_不存在(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
_, ok := store.Get("nonexistent")
if ok {
t.Error("不存在的 session 应返回 false")
}
}
func TestSessionStore_Get_过期(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "expired-state",
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
}
store.Set("expired-session", session)
_, ok := store.Get("expired-session")
if ok {
t.Error("过期的 session 应返回 false")
}
}
func TestSessionStore_Delete(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "to-delete",
CreatedAt: time.Now(),
}
store.Set("del-session", session)
store.Delete("del-session")
_, ok := store.Get("del-session")
if ok {
t.Error("删除后 Get 应返回 false")
}
}
func TestSessionStore_Delete_不存在(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
// 删除不存在的 session 不应 panic
store.Delete("nonexistent")
}
func TestSessionStore_Stop(t *testing.T) {
store := NewSessionStore()
store.Stop()
// 多次 Stop 不应 panic
store.Stop()
}
func TestSessionStore_多个Session(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
for i := 0; i < 10; i++ {
session := &OAuthSession{
State: "state-" + string(rune('0'+i)),
CreatedAt: time.Now(),
}
store.Set("session-"+string(rune('0'+i)), session)
}
// 验证都能取到
for i := 0; i < 10; i++ {
_, ok := store.Get("session-" + string(rune('0'+i)))
if !ok {
t.Errorf("session-%d 应存在", i)
}
}
}
// ---------------------------------------------------------------------------
// GenerateRandomBytes
// ---------------------------------------------------------------------------
func TestGenerateRandomBytes_长度正确(t *testing.T) {
sizes := []int{0, 1, 16, 32, 64, 128}
for _, size := range sizes {
b, err := GenerateRandomBytes(size)
if err != nil {
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
}
if len(b) != size {
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
}
}
}
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
b1, err := GenerateRandomBytes(32)
if err != nil {
t.Fatalf("第一次调用失败: %v", err)
}
b2, err := GenerateRandomBytes(32)
if err != nil {
t.Fatalf("第二次调用失败: %v", err)
}
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
if string(b1) == string(b2) {
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
}
}
// ---------------------------------------------------------------------------
// GenerateState
// ---------------------------------------------------------------------------
func TestGenerateState_返回值格式(t *testing.T) {
state, err := GenerateState()
if err != nil {
t.Fatalf("GenerateState 失败: %v", err)
}
if state == "" {
t.Error("GenerateState 返回空字符串")
}
// base64url 编码不应包含 +, /, =
if strings.ContainsAny(state, "+/=") {
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
}
// 32 字节的 base64url 编码长度应为 43去掉了尾部 = 填充)
if len(state) != 43 {
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
}
}
func TestGenerateState_唯一性(t *testing.T) {
s1, _ := GenerateState()
s2, _ := GenerateState()
if s1 == s2 {
t.Error("两次 GenerateState 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateSessionID
// ---------------------------------------------------------------------------
func TestGenerateSessionID_返回值格式(t *testing.T) {
id, err := GenerateSessionID()
if err != nil {
t.Fatalf("GenerateSessionID 失败: %v", err)
}
if id == "" {
t.Error("GenerateSessionID 返回空字符串")
}
// 16 字节的 hex 编码长度应为 32
if len(id) != 32 {
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
}
// 验证是合法的 hex 字符串
if _, err := hex.DecodeString(id); err != nil {
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
}
}
func TestGenerateSessionID_唯一性(t *testing.T) {
id1, _ := GenerateSessionID()
id2, _ := GenerateSessionID()
if id1 == id2 {
t.Error("两次 GenerateSessionID 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeVerifier
// ---------------------------------------------------------------------------
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
verifier, err := GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
}
if verifier == "" {
t.Error("GenerateCodeVerifier 返回空字符串")
}
// base64url 编码不应包含 +, /, =
if strings.ContainsAny(verifier, "+/=") {
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
}
// 32 字节的 base64url 编码长度应为 43
if len(verifier) != 43 {
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
}
}
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
v1, _ := GenerateCodeVerifier()
v2, _ := GenerateCodeVerifier()
if v1 == v2 {
t.Error("两次 GenerateCodeVerifier 结果相同")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeChallenge
// ---------------------------------------------------------------------------
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
challenge := GenerateCodeChallenge(verifier)
// 手动计算预期值
hash := sha256.Sum256([]byte(verifier))
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
if challenge != expected {
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
}
}
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
challenge := GenerateCodeChallenge("test-verifier")
if strings.Contains(challenge, "=") {
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
}
}
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
challenge := GenerateCodeChallenge("another-verifier")
if strings.ContainsAny(challenge, "+/") {
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
}
}
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
c1 := GenerateCodeChallenge("same-verifier")
c2 := GenerateCodeChallenge("same-verifier")
if c1 != c2 {
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
}
}
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
c1 := GenerateCodeChallenge("verifier-1")
c2 := GenerateCodeChallenge("verifier-2")
if c1 == c2 {
t.Error("不同输入应产生不同输出")
}
}
// ---------------------------------------------------------------------------
// BuildAuthorizationURL
// ---------------------------------------------------------------------------
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
state := "test-state-123"
codeChallenge := "test-challenge-abc"
authURL := BuildAuthorizationURL(state, codeChallenge)
// 验证以 AuthorizeURL 开头
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
}
// 解析 URL 并验证参数
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
params := parsed.Query()
expectedParams := map[string]string{
"client_id": ClientID,
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"include_granted_scopes": "true",
}
for key, want := range expectedParams {
got := params.Get(key)
if got != want {
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
}
}
}
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
authURL := BuildAuthorizationURL("s", "c")
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
params := parsed.Query()
// 应包含 10 个参数
expectedCount := 10
if len(params) != expectedCount {
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
}
}
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
state := "state+with/special=chars"
codeChallenge := "challenge+value"
authURL := BuildAuthorizationURL(state, codeChallenge)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("解析 URL 失败: %v", err)
}
// 解析后应正确还原特殊字符
if got := parsed.Query().Get("state"); got != state {
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
}
}
// ---------------------------------------------------------------------------
// 常量值验证
// ---------------------------------------------------------------------------
func TestConstants_值正确(t *testing.T) {
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
}
if TokenURL != "https://oauth2.googleapis.com/token" {
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
}
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
}
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
t.Errorf("ClientID 不匹配: got %s", ClientID)
}
if ClientSecret != "" {
t.Error("ClientSecret 应为空字符串")
}
if RedirectURI != "http://localhost:8085/callback" {
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
}
if UserAgent != "antigravity/1.15.8 windows/amd64" {
t.Errorf("UserAgent 不匹配: got %s", UserAgent)
}
if SessionTTL != 30*time.Minute {
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
}
if URLAvailabilityTTL != 5*time.Minute {
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
}
}
func TestScopes_包含必要范围(t *testing.T) {
expectedScopes := []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
for _, scope := range expectedScopes {
if !strings.Contains(Scopes, scope) {
t.Errorf("Scopes 缺少 %s", scope)
}
}
}

View File

@@ -1,10 +1,13 @@
package antigravity
import (
"crypto/rand"
"encoding/json"
"fmt"
"log"
"strings"
"sync/atomic"
"time"
)
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
return builder.String()
}
// generateRandomID 生成随机 ID
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
var fallbackCounter uint64
// generateRandomID 生成密码学安全的随机 ID
func generateRandomID() string {
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, 12)
for i := range result {
result[i] = chars[i%len(chars)]
id := make([]byte, 12)
randBytes := make([]byte, 12)
if _, err := rand.Read(randBytes); err != nil {
// 避免在请求路径里 panic极端情况下熵源不可用时降级为伪随机。
// 这里主要用于生成响应/工具调用的临时 ID安全要求不高但需尽量避免碰撞。
cnt := atomic.AddUint64(&fallbackCounter, 1)
seed := uint64(time.Now().UnixNano()) ^ cnt
seed ^= uint64(len(err.Error())) << 32
for i := range id {
seed ^= seed << 13
seed ^= seed >> 7
seed ^= seed << 17
id[i] = chars[int(seed)%len(chars)]
}
return string(id)
}
return string(result)
for i, b := range randBytes {
id[i] = chars[int(b)%len(chars)]
}
return string(id)
}

View File

@@ -0,0 +1,109 @@
//go:build unit
package antigravity
import (
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
func TestGenerateRandomID_Uniqueness(t *testing.T) {
seen := make(map[string]struct{}, 100)
for i := 0; i < 100; i++ {
id := generateRandomID()
require.Len(t, id, 12, "ID 长度应为 12")
_, dup := seen[id]
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
seen[id] = struct{}{}
}
}
func TestFallbackCounter_Increments(t *testing.T) {
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
before := atomic.LoadUint64(&fallbackCounter)
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
}
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
// 验证并发递增的原子性 — 每次递增都应产生唯一值
const goroutines = 50
results := make([]uint64, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
}(i)
}
wg.Wait()
// 所有结果应唯一
seen := make(map[uint64]bool, goroutines)
for _, v := range results {
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
seen[v] = true
}
}
func TestGenerateRandomID_Charset(t *testing.T) {
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
validSet := make(map[byte]struct{}, len(validChars))
for i := 0; i < len(validChars); i++ {
validSet[validChars[i]] = struct{}{}
}
for i := 0; i < 50; i++ {
id := generateRandomID()
for j := 0; j < len(id); j++ {
_, ok := validSet[id[j]]
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
}
}
}
func TestGenerateRandomID_Length(t *testing.T) {
for i := 0; i < 100; i++ {
id := generateRandomID()
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
}
}
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
// 验证并发调用不会产生重复 ID
const goroutines = 100
results := make([]string, goroutines)
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
results[idx] = generateRandomID()
}(i)
}
wg.Wait()
seen := make(map[string]bool, goroutines)
for _, id := range results {
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
seen[id] = true
}
}
func BenchmarkGenerateRandomID(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = generateRandomID()
}
}

View File

@@ -8,9 +8,21 @@ const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
// RequestID 为服务端生成/透传的请求 ID。
RequestID Key = "ctx_request_id"
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
ClientRequestID Key = "ctx_client_request_id"
// Model 请求模型标识(用于统一请求链路日志字段)。
Model Key = "ctx_model"
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
Platform Key = "ctx_platform"
// AccountID 当前请求最终命中的账号 ID用于统一请求链路日志字段
AccountID Key = "ctx_account_id"
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
RetryCount Key = "ctx_retry_count"
@@ -32,4 +44,12 @@ const (
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
// 在此模式下Service 层的模型限流预检查将等待限流过期而非直接切换账号。
SingleAccountRetry Key = "ctx_single_account_retry"
// PrefetchedStickyAccountID 标识上游(通常 handler预取到的 sticky session 账号 ID。
// Service 层可复用该值,避免同请求链路重复读取 Redis。
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
)

View File

@@ -21,6 +21,7 @@ func DefaultModels() []Model {
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
}
}

View File

@@ -38,8 +38,13 @@ const (
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
// GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
GeminiCLIOAuthClientSecret = ""
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
SessionTTL = 30 * time.Minute

View File

@@ -16,6 +16,7 @@ var DefaultModels = []Model{
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.

View File

@@ -6,10 +6,14 @@ import (
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
type OAuthConfig struct {
@@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
}
// Fall back to built-in Gemini CLI OAuth client when not configured.
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
if effective.ClientID == "" && effective.ClientSecret == "" {
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
if secret == "" {
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
secret = strings.TrimSpace(v)
}
}
if secret == "" {
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
}
effective.ClientID = GeminiCLIOAuthClientID
effective.ClientSecret = GeminiCLIOAuthClientSecret
effective.ClientSecret = secret
} else if effective.ClientID == "" || effective.ClientSecret == "" {
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
}
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
effective.ClientSecret == GeminiCLIOAuthClientSecret
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID
if effective.Scopes == "" {
// Use different default scopes based on OAuth type

View File

@@ -1,11 +1,439 @@
package geminicli
import (
"encoding/hex"
"strings"
"sync"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// SessionStore 测试
// ---------------------------------------------------------------------------
func TestSessionStore_SetAndGet(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "test-state",
OAuthType: "code_assist",
CreatedAt: time.Now(),
}
store.Set("sid-1", session)
got, ok := store.Get("sid-1")
if !ok {
t.Fatal("期望 Get 返回 ok=true实际返回 false")
}
if got.State != "test-state" {
t.Errorf("期望 State=%q实际=%q", "test-state", got.State)
}
}
func TestSessionStore_GetNotFound(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
_, ok := store.Get("不存在的ID")
if ok {
t.Error("期望不存在的 sessionID 返回 ok=false")
}
}
func TestSessionStore_GetExpired(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
// 创建一个已过期的 sessionCreatedAt 设置为 SessionTTL+1 分钟之前)
session := &OAuthSession{
State: "expired-state",
OAuthType: "code_assist",
CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)),
}
store.Set("expired-sid", session)
_, ok := store.Get("expired-sid")
if ok {
t.Error("期望过期的 session 返回 ok=false")
}
}
func TestSessionStore_Delete(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
session := &OAuthSession{
State: "to-delete",
OAuthType: "code_assist",
CreatedAt: time.Now(),
}
store.Set("del-sid", session)
// 先确认存在
if _, ok := store.Get("del-sid"); !ok {
t.Fatal("删除前 session 应该存在")
}
store.Delete("del-sid")
if _, ok := store.Get("del-sid"); ok {
t.Error("删除后 session 不应该存在")
}
}
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
// 多次调用 Stop 不应 panic
store.Stop()
store.Stop()
store.Stop()
}
func TestSessionStore_ConcurrentAccess(t *testing.T) {
store := NewSessionStore()
defer store.Stop()
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines * 3)
// 并发写入
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Set(sid, &OAuthSession{
State: sid,
OAuthType: "code_assist",
CreatedAt: time.Now(),
})
}(i)
}
// 并发读取
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Get(sid) // 可能找到也可能没找到,关键是不 panic
}(i)
}
// 并发删除
for i := 0; i < goroutines; i++ {
go func(idx int) {
defer wg.Done()
sid := "concurrent-" + string(rune('A'+idx%26))
store.Delete(sid)
}(i)
}
wg.Wait()
}
// ---------------------------------------------------------------------------
// GenerateRandomBytes 测试
// ---------------------------------------------------------------------------
func TestGenerateRandomBytes(t *testing.T) {
tests := []int{0, 1, 16, 32, 64}
for _, n := range tests {
b, err := GenerateRandomBytes(n)
if err != nil {
t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err)
continue
}
if len(b) != n {
t.Errorf("GenerateRandomBytes(%d) 返回长度=%d期望=%d", n, len(b), n)
}
}
}
func TestGenerateRandomBytes_Uniqueness(t *testing.T) {
// 两次调用应该返回不同的结果极小概率相同32字节足够
a, _ := GenerateRandomBytes(32)
b, _ := GenerateRandomBytes(32)
if string(a) == string(b) {
t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题")
}
}
// ---------------------------------------------------------------------------
// GenerateState 测试
// ---------------------------------------------------------------------------
func TestGenerateState(t *testing.T) {
state, err := GenerateState()
if err != nil {
t.Fatalf("GenerateState() 出错: %v", err)
}
if state == "" {
t.Error("GenerateState() 返回空字符串")
}
// base64url 编码不应包含 padding '='
if strings.Contains(state, "=") {
t.Errorf("GenerateState() 结果包含 '=' padding: %s", state)
}
// base64url 不应包含 '+' 或 '/'
if strings.ContainsAny(state, "+/") {
t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state)
}
}
// ---------------------------------------------------------------------------
// GenerateSessionID 测试
// ---------------------------------------------------------------------------
func TestGenerateSessionID(t *testing.T) {
sid, err := GenerateSessionID()
if err != nil {
t.Fatalf("GenerateSessionID() 出错: %v", err)
}
// 16 字节 -> 32 个 hex 字符
if len(sid) != 32 {
t.Errorf("GenerateSessionID() 长度=%d期望=32", len(sid))
}
// 必须是合法的 hex 字符串
if _, err := hex.DecodeString(sid); err != nil {
t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err)
}
}
func TestGenerateSessionID_Uniqueness(t *testing.T) {
a, _ := GenerateSessionID()
b, _ := GenerateSessionID()
if a == b {
t.Error("两次 GenerateSessionID() 返回了相同结果")
}
}
// ---------------------------------------------------------------------------
// GenerateCodeVerifier 测试
// ---------------------------------------------------------------------------
func TestGenerateCodeVerifier(t *testing.T) {
verifier, err := GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier() 出错: %v", err)
}
if verifier == "" {
t.Error("GenerateCodeVerifier() 返回空字符串")
}
// RFC 7636 要求 code_verifier 至少 43 个字符
if len(verifier) < 43 {
t.Errorf("GenerateCodeVerifier() 长度=%dRFC 7636 要求至少 43 字符", len(verifier))
}
// base64url 编码不应包含 padding 和非 URL 安全字符
if strings.Contains(verifier, "=") {
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier)
}
if strings.ContainsAny(verifier, "+/") {
t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier)
}
}
// ---------------------------------------------------------------------------
// GenerateCodeChallenge 测试
// ---------------------------------------------------------------------------
func TestGenerateCodeChallenge(t *testing.T) {
// 使用已知输入验证输出
// RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
// 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
challenge := GenerateCodeChallenge(verifier)
if challenge != expected {
t.Errorf("GenerateCodeChallenge(%q) = %q期望 %q", verifier, challenge, expected)
}
}
func TestGenerateCodeChallenge_NoPadding(t *testing.T) {
challenge := GenerateCodeChallenge("test-verifier-string")
if strings.Contains(challenge, "=") {
t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge)
}
}
// ---------------------------------------------------------------------------
// base64URLEncode 测试
// ---------------------------------------------------------------------------
func TestBase64URLEncode(t *testing.T) {
tests := []struct {
name string
input []byte
}{
{"空字节", []byte{}},
{"单字节", []byte{0xff}},
{"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}},
{"全零", []byte{0x00, 0x00, 0x00}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := base64URLEncode(tt.input)
// 不应包含 '=' padding
if strings.Contains(result, "=") {
t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result)
}
// 不应包含标准 base64 的 '+' 或 '/'
if strings.ContainsAny(result, "+/") {
t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result)
}
})
}
}
// ---------------------------------------------------------------------------
// hasRestrictedScope 测试
// ---------------------------------------------------------------------------
func TestHasRestrictedScope(t *testing.T) {
tests := []struct {
scope string
expected bool
}{
// 受限 scope
{"https://www.googleapis.com/auth/generative-language", true},
{"https://www.googleapis.com/auth/generative-language.retriever", true},
{"https://www.googleapis.com/auth/generative-language.tuning", true},
{"https://www.googleapis.com/auth/drive", true},
{"https://www.googleapis.com/auth/drive.readonly", true},
{"https://www.googleapis.com/auth/drive.file", true},
// 非受限 scope
{"https://www.googleapis.com/auth/cloud-platform", false},
{"https://www.googleapis.com/auth/userinfo.email", false},
{"https://www.googleapis.com/auth/userinfo.profile", false},
// 边界情况
{"", false},
{"random-scope", false},
}
for _, tt := range tests {
t.Run(tt.scope, func(t *testing.T) {
got := hasRestrictedScope(tt.scope)
if got != tt.expected {
t.Errorf("hasRestrictedScope(%q) = %v期望 %v", tt.scope, got, tt.expected)
}
})
}
}
// ---------------------------------------------------------------------------
// BuildAuthorizationURL 测试
// ---------------------------------------------------------------------------
func TestBuildAuthorizationURL(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
}
// 检查返回的 URL 包含期望的参数
checks := []string{
"response_type=code",
"client_id=" + GeminiCLIOAuthClientID,
"redirect_uri=",
"state=test-state",
"code_challenge=test-challenge",
"code_challenge_method=S256",
"access_type=offline",
"prompt=consent",
"include_granted_scopes=true",
}
for _, check := range checks {
if !strings.Contains(authURL, check) {
t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL)
}
}
// 不应包含 project_id因为传的是空字符串
if strings.Contains(authURL, "project_id=") {
t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数")
}
// URL 应该以正确的授权端点开头
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL)
}
}
func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
_, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"", // 空 redirectURI
"",
"code_assist",
)
if err == nil {
t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错")
}
if !strings.Contains(err.Error(), "redirect_uri") {
t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err)
}
}
func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
authURL, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"my-project-123",
"code_assist",
)
if err != nil {
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
}
if !strings.Contains(authURL, "project_id=my-project-123") {
t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL)
}
}
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
// 不设置环境变量,也不提供 client 凭据EffectiveOAuthConfig 应该报错
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
_, err := BuildAuthorizationURL(
OAuthConfig{},
"test-state",
"test-challenge",
"https://example.com/callback",
"",
"code_assist",
)
if err == nil {
t.Error("当 EffectiveOAuthConfig 失败时BuildAuthorizationURL 应该返回错误")
}
}
// ---------------------------------------------------------------------------
// EffectiveOAuthConfig 测试 - 原有测试
// ---------------------------------------------------------------------------
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
// 内置的 Gemini CLI client secret 不嵌入在此仓库中。
// 测试通过环境变量设置一个假的 secret 来模拟运维配置。
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
tests := []struct {
name string
input OAuthConfig
@@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr bool
}{
{
name: "Google One with built-in client (empty config)",
name: "Google One 使用内置客户端(空配置)",
input: OAuthConfig{},
oauthType: "google_one",
wantClientID: GeminiCLIOAuthClientID,
@@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Google One always uses built-in client (even if custom credentials passed)",
name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)",
input: OAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
wantScopes: DefaultCodeAssistScopes,
wantErr: false,
},
{
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
name: "Google One 内置客户端 + 自定义 scopes应过滤受限 scopes",
input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
},
@@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
name: "Google One 内置客户端 + 仅受限 scopes应回退到默认",
input: OAuthConfig{
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
},
@@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
wantErr: false,
},
{
name: "Code Assist with built-in client",
name: "Code Assist 使用内置客户端",
input: OAuthConfig{},
oauthType: "code_assist",
wantClientID: GeminiCLIOAuthClientID,
@@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
}
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
// Test that Google One with built-in client filters out restricted scopes
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 测试 Google One + 内置客户端过滤受限 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
}, "google_one")
@@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
// Should NOT contain generative-language or drive scopes
// 应仅包含 cloud-platformuserinfo.email userinfo.profile
// 不应包含 generative-language drive scopes
if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language实际: %v", cfg.Scopes)
}
if strings.Contains(cfg.Scopes, "drive") {
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
t.Errorf("使用内置客户端时 Scopes 不应包含 drive实际: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes)
}
}
// ---------------------------------------------------------------------------
// EffectiveOAuthConfig 测试 - 新增分支覆盖
// ---------------------------------------------------------------------------
func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) {
// 只提供 clientID 不提供 secret 应报错
_, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "some-client-id",
}, "code_assist")
if err == nil {
t.Error("只提供 ClientID 不提供 ClientSecret 应该报错")
}
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
t.Errorf("错误消息应提及 client_id 和 client_secret实际: %v", err)
}
}
func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) {
// 只提供 secret 不提供 clientID 应报错
_, err := EffectiveOAuthConfig(OAuthConfig{
ClientSecret: "some-client-secret",
}, "code_assist")
if err == nil {
t.Error("只提供 ClientSecret 不提供 ClientID 应该报错")
}
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
t.Errorf("错误消息应提及 client_id 和 client_secret实际: %v", err)
}
}
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// ai_studio 类型使用内置客户端scopes 为空 -> 应使用 DefaultCodeAssistScopes因为内置客户端不能请求 generative-language scope
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) {
// ai_studio 类型使用自定义客户端scopes 为空 -> 应使用 DefaultAIStudioScopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultAIStudioScopes {
t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) {
// ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") {
// 确保不包含未归一化的旧 scope仅 generative-language 而非 generative-language.retriever
parts := strings.Fields(cfg.Scopes)
for _, p := range parts {
if p == "https://www.googleapis.com/auth/generative-language" {
t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever实际 scopes: %q", cfg.Scopes)
}
}
}
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
t.Errorf("ai_studio 归一化后应包含 generative-language.retriever实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 逗号分隔的 scopes 应被归一化为空格分隔
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 应该用空格分隔,而非逗号
if strings.Contains(cfg.Scopes, ",") {
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("归一化后应包含 cloud-platform实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "userinfo.email") {
t.Errorf("归一化后应包含 userinfo.email实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) {
// 混合逗号和空格分隔的 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
parts := strings.Fields(cfg.Scopes)
if len(parts) != 3 {
t.Errorf("归一化后应有 3 个 scope实际: %dscopes: %q", len(parts), cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
// 输入中的前后空白应被清理
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: " custom-id ",
ClientSecret: " custom-secret ",
Scopes: " https://www.googleapis.com/auth/cloud-platform ",
}, "code_assist")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.ClientID != "custom-id" {
t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID)
}
if cfg.ClientSecret != "custom-secret" {
t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret)
}
if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" {
t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
// 不设置环境变量且不提供凭据,应该报错
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
if err == nil {
t.Error("没有内置 secret 且未提供凭据时应该报错")
}
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
t.Errorf("错误消息应提及环境变量 %s实际: %v", GeminiCLIOAuthClientSecretEnv, err)
}
}
func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever",
}, "ai_studio")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 内置客户端应过滤 generative-language.retriever
if strings.Contains(cfg.Scopes, "generative-language") {
t.Errorf("ai_studio + 内置客户端应过滤受限 scopes实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "cloud-platform") {
t.Errorf("应保留 cloud-platform scope实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 未知的 oauthType 应回退到默认的 code_assist scopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) {
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
// 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
if cfg.Scopes != DefaultCodeAssistScopes {
t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes实际: %q", cfg.Scopes)
}
}
func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) {
// 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端)
cfg, err := EffectiveOAuthConfig(OAuthConfig{
ClientID: "custom-id",
ClientSecret: "custom-secret",
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
}, "google_one")
if err != nil {
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
}
// 自定义客户端不应过滤任何 scope
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
t.Errorf("自定义客户端不应过滤 generative-language.retriever实际: %q", cfg.Scopes)
}
if !strings.Contains(cfg.Scopes, "drive.readonly") {
t.Errorf("自定义客户端不应过滤 drive.readonly实际: %q", cfg.Scopes)
}
}

View File

@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
return normalizeIP(c.ClientIP())
}
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
// 适用于 ACL / 风控等安全敏感场景。
func GetTrustedClientIP(c *gin.Context) string {
if c == nil {
return ""
}
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip)
@@ -54,29 +64,34 @@ func normalizeIP(ip string) string {
return ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var privateNets []*net.IPNet
// 私有 IP 范围
privateBlocks := []string{
func init() {
for _, cidr := range []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
} {
_, block, err := net.ParseCIDR(cidr)
if err != nil {
continue
panic("invalid CIDR: " + cidr)
}
if cidr.Contains(ip) {
privateNets = append(privateNets, block)
}
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
for _, block := range privateNets {
if block.Contains(ip) {
return true
}
}

View File

@@ -0,0 +1,75 @@
//go:build unit
package ip
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
// 私有 IPv4
{"10.x 私有地址", "10.0.0.1", true},
{"10.x 私有地址段末", "10.255.255.255", true},
{"172.16.x 私有地址", "172.16.0.1", true},
{"172.31.x 私有地址", "172.31.255.255", true},
{"192.168.x 私有地址", "192.168.1.1", true},
{"127.0.0.1 本地回环", "127.0.0.1", true},
{"127.x 回环段", "127.255.255.255", true},
// 公网 IPv4
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
{"1.1.1.1 公网", "1.1.1.1", false},
{"172.15.255.255 非私有", "172.15.255.255", false},
{"172.32.0.0 非私有", "172.32.0.0", false},
{"11.0.0.1 公网", "11.0.0.1", false},
// IPv6
{"::1 IPv6 回环", "::1", true},
{"fc00:: IPv6 私有", "fc00::1", true},
{"fd00:: IPv6 私有", "fd00::1", true},
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
// 无效输入
{"空字符串", "", false},
{"非法字符串", "not-an-ip", false},
{"不完整 IP", "192.168", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := isPrivateIP(tc.ip)
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
})
}
}
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
require.NoError(t, r.SetTrustedProxies(nil))
r.GET("/t", func(c *gin.Context) {
c.String(200, GetTrustedClientIP(c))
})
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
r.ServeHTTP(w, req)
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}

View File

@@ -0,0 +1,31 @@
package logger
import "github.com/Wei-Shaw/sub2api/internal/config"
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
return InitOptions{
Level: cfg.Level,
Format: cfg.Format,
ServiceName: cfg.ServiceName,
Environment: cfg.Environment,
Caller: cfg.Caller,
StacktraceLevel: cfg.StacktraceLevel,
Output: OutputOptions{
ToStdout: cfg.Output.ToStdout,
ToFile: cfg.Output.ToFile,
FilePath: cfg.Output.FilePath,
},
Rotation: RotationOptions{
MaxSizeMB: cfg.Rotation.MaxSizeMB,
MaxBackups: cfg.Rotation.MaxBackups,
MaxAgeDays: cfg.Rotation.MaxAgeDays,
Compress: cfg.Rotation.Compress,
LocalTime: cfg.Rotation.LocalTime,
},
Sampling: SamplingOptions{
Enabled: cfg.Sampling.Enabled,
Initial: cfg.Sampling.Initial,
Thereafter: cfg.Sampling.Thereafter,
},
}
}

View File

@@ -0,0 +1,519 @@
package logger
import (
"context"
"fmt"
"io"
"log"
"log/slog"
"os"
"path/filepath"
"strings"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
type Level = zapcore.Level
const (
LevelDebug = zapcore.DebugLevel
LevelInfo = zapcore.InfoLevel
LevelWarn = zapcore.WarnLevel
LevelError = zapcore.ErrorLevel
LevelFatal = zapcore.FatalLevel
)
type Sink interface {
WriteLogEvent(event *LogEvent)
}
type LogEvent struct {
Time time.Time
Level string
Component string
Message string
LoggerName string
Fields map[string]any
}
var (
mu sync.RWMutex
global *zap.Logger
sugar *zap.SugaredLogger
atomicLevel zap.AtomicLevel
initOptions InitOptions
currentSink Sink
stdLogUndo func()
bootstrapOnce sync.Once
)
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
}
})
}
func Init(options InitOptions) error {
mu.Lock()
defer mu.Unlock()
return initLocked(options)
}
func initLocked(options InitOptions) error {
normalized := options.normalized()
zl, al, err := buildLogger(normalized)
if err != nil {
return err
}
prev := global
global = zl
sugar = zl.Sugar()
atomicLevel = al
initOptions = normalized
bridgeSlogLocked()
bridgeStdLogLocked()
if prev != nil {
_ = prev.Sync()
}
return nil
}
func Reconfigure(mutator func(*InitOptions) error) error {
mu.Lock()
defer mu.Unlock()
next := initOptions
if mutator != nil {
if err := mutator(&next); err != nil {
return err
}
}
return initLocked(next)
}
func SetLevel(level string) error {
lv, ok := parseLevel(level)
if !ok {
return fmt.Errorf("invalid log level: %s", level)
}
mu.Lock()
defer mu.Unlock()
atomicLevel.SetLevel(lv)
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
return nil
}
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
if global == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
mu.Lock()
defer mu.Unlock()
currentSink = sink
}
// WriteSinkEvent 直接写入日志 sink不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
mu.RLock()
sink := currentSink
mu.RUnlock()
if sink == nil {
return
}
level = strings.ToLower(strings.TrimSpace(level))
if level == "" {
level = "info"
}
component = strings.TrimSpace(component)
message = strings.TrimSpace(message)
if message == "" {
return
}
eventFields := make(map[string]any, len(fields)+1)
for k, v := range fields {
eventFields[k] = v
}
if component != "" {
if _, ok := eventFields["component"]; !ok {
eventFields["component"] = component
}
}
sink.WriteLogEvent(&LogEvent{
Time: time.Now(),
Level: level,
Component: component,
Message: message,
LoggerName: component,
Fields: eventFields,
})
}
func L() *zap.Logger {
mu.RLock()
defer mu.RUnlock()
if global != nil {
return global
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
mu.RLock()
defer mu.RUnlock()
if sugar != nil {
return sugar
}
return zap.NewNop().Sugar()
}
func With(fields ...zap.Field) *zap.Logger {
return L().With(fields...)
}
func Sync() {
mu.RLock()
l := global
mu.RUnlock()
if l != nil {
_ = l.Sync()
}
}
func bridgeStdLogLocked() {
if stdLogUndo != nil {
stdLogUndo()
stdLogUndo = nil
}
prevFlags := log.Flags()
prevPrefix := log.Prefix()
prevWriter := log.Writer()
log.SetFlags(0)
log.SetPrefix("")
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
log.SetFlags(prevFlags)
log.SetPrefix(prevPrefix)
}
}
func bridgeSlogLocked() {
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
level, _ := parseLevel(options.Level)
atomic := zap.NewAtomicLevelAt(level)
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.MillisDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
var enc zapcore.Encoder
if options.Format == "console" {
enc = zapcore.NewConsoleEncoder(encoderCfg)
} else {
enc = zapcore.NewJSONEncoder(encoderCfg)
}
sinkCore := newSinkCore()
cores := make([]zapcore.Core, 0, 3)
if options.Output.ToStdout {
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
})
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
})
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
}
if options.Output.ToFile {
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
if fileErr != nil {
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
time.Now().Format(time.RFC3339Nano),
filePath,
fileErr,
)
} else {
cores = append(cores, fileCore)
}
}
if len(cores) == 0 {
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
}
core := zapcore.NewTee(cores...)
if options.Sampling.Enabled {
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
}
core = sinkCore.Wrap(core)
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
zapOpts := make([]zap.Option, 0, 5)
if options.Caller {
zapOpts = append(zapOpts, zap.AddCaller())
}
if stacktraceLevel <= zapcore.FatalLevel {
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
}
logger := zap.New(core, zapOpts...).With(
zap.String("service", options.ServiceName),
zap.String("env", options.Environment),
)
return logger, atomic, nil
}
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
filePath := options.Output.FilePath
if strings.TrimSpace(filePath) == "" {
filePath = resolveLogFilePath("")
}
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, filePath, err
}
lj := &lumberjack.Logger{
Filename: filePath,
MaxSize: options.Rotation.MaxSizeMB,
MaxBackups: options.Rotation.MaxBackups,
MaxAge: options.Rotation.MaxAgeDays,
Compress: options.Rotation.Compress,
LocalTime: options.Rotation.LocalTime,
}
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
}
type sinkCore struct {
core zapcore.Core
fields []zapcore.Field
}
func newSinkCore() *sinkCore {
return &sinkCore{}
}
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
cp := *s
cp.core = core
return &cp
}
func (s *sinkCore) Enabled(level zapcore.Level) bool {
return s.core.Enabled(level)
}
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := append([]zapcore.Field{}, s.fields...)
nextFields = append(nextFields, fields...)
return &sinkCore{
core: s.core.With(fields),
fields: nextFields,
}
}
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
// Delegate to inner core (tee) so each sub-core's level enabler is respected.
// Then add ourselves for sink forwarding only.
ce = s.core.Check(entry, ce)
if ce != nil {
ce = ce.AddCore(entry, s)
}
return ce
}
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
mu.RLock()
sink := currentSink
mu.RUnlock()
if sink == nil {
return nil
}
enc := zapcore.NewMapObjectEncoder()
for _, f := range s.fields {
f.AddTo(enc)
}
for _, f := range fields {
f.AddTo(enc)
}
event := &LogEvent{
Time: entry.Time,
Level: strings.ToLower(entry.Level.String()),
Component: entry.LoggerName,
Message: entry.Message,
LoggerName: entry.LoggerName,
Fields: enc.Fields,
}
sink.WriteLogEvent(event)
return nil
}
func (s *sinkCore) Sync() error {
return s.core.Sync()
}
type stdLogBridge struct {
logger *zap.Logger
}
func newStdLogBridge(l *zap.Logger) io.Writer {
if l == nil {
l = zap.NewNop()
}
return &stdLogBridge{logger: l}
}
func (b *stdLogBridge) Write(p []byte) (int, error) {
msg := normalizeStdLogMessage(string(p))
if msg == "" {
return len(p), nil
}
level := inferStdLogLevel(msg)
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
switch level {
case LevelDebug:
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
case LevelWarn:
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
case LevelError, LevelFatal:
entry.Error(msg, zap.Bool("legacy_stdlog", true))
default:
entry.Info(msg, zap.Bool("legacy_stdlog", true))
}
return len(p), nil
}
func normalizeStdLogMessage(raw string) string {
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
if msg == "" {
return ""
}
return strings.Join(strings.Fields(msg), " ")
}
func inferStdLogLevel(msg string) Level {
lower := strings.ToLower(strings.TrimSpace(msg))
if lower == "" {
return LevelInfo
}
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
return LevelDebug
}
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
return LevelWarn
}
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
return LevelError
}
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
}
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
func LegacyPrintf(component, format string, args ...any) {
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
if msg == "" {
return
}
mu.RLock()
initialized := global != nil
mu.RUnlock()
if !initialized {
// 在日志系统未初始化前,回退到标准库 log避免测试/工具链丢日志。
log.Print(msg)
return
}
l := L()
if component != "" {
l = l.With(zap.String("component", component))
}
l = l.WithOptions(zap.AddCallerSkip(1))
switch inferStdLogLevel(msg) {
case LevelDebug:
l.Debug(msg, zap.Bool("legacy_printf", true))
case LevelWarn:
l.Warn(msg, zap.Bool("legacy_printf", true))
case LevelError, LevelFatal:
l.Error(msg, zap.Bool("legacy_printf", true))
default:
l.Info(msg, zap.Bool("legacy_printf", true))
}
}
type contextKey string
const loggerContextKey contextKey = "ctx_logger"
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
if ctx == nil {
ctx = context.Background()
}
if l == nil {
l = L()
}
return context.WithValue(ctx, loggerContextKey, l)
}
func FromContext(ctx context.Context) *zap.Logger {
if ctx == nil {
return L()
}
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
return l
}
return L()
}

View File

@@ -0,0 +1,192 @@
package logger
import (
"encoding/json"
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestInit_DualOutput(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stderrR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: logPath,
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 2,
MaxAgeDays: 1,
},
Sampling: SamplingOptions{Enabled: false},
})
if err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("dual-output-info")
L().Warn("dual-output-warn")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "dual-output-info") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "dual-output-warn") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
fileBytes, err := os.ReadFile(logPath)
if err != nil {
t.Fatalf("read log file: %v", err)
}
fileText := string(fileBytes)
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
t.Fatalf("file missing logs: %s", fileText)
}
}
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
_, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
err = Init(InitOptions{
Level: "info",
Format: "json",
Output: OutputOptions{
ToStdout: true,
ToFile: true,
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
},
Rotation: RotationOptions{
MaxSizeMB: 10,
MaxBackups: 1,
MaxAgeDays: 1,
},
})
if err != nil {
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
}
_ = stderrW.Close()
stderrBytes, _ := io.ReadAll(stderrR)
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
}
}
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
_, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Caller: true,
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
L().Info("caller-check")
Sync()
_ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR)
var line string
for _, item := range strings.Split(string(logBytes), "\n") {
if strings.Contains(item, "caller-check") {
line = item
break
}
}
if line == "" {
t.Fatalf("log output missing caller-check: %s", string(logBytes))
}
var payload map[string]any
if err := json.Unmarshal([]byte(line), &payload); err != nil {
t.Fatalf("parse log json failed: %v, line=%s", err, line)
}
caller, _ := payload["caller"].(string)
if !strings.Contains(caller, "logger_test.go:") {
t.Fatalf("caller should point to this test file, got: %s", caller)
}
}

View File

@@ -0,0 +1,161 @@
package logger
import (
"os"
"path/filepath"
"strings"
"time"
)
const (
// DefaultContainerLogPath 为容器内默认日志文件路径。
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
defaultLogFilename = "sub2api.log"
)
type InitOptions struct {
Level string
Format string
ServiceName string
Environment string
Caller bool
StacktraceLevel string
Output OutputOptions
Rotation RotationOptions
Sampling SamplingOptions
}
type OutputOptions struct {
ToStdout bool
ToFile bool
FilePath string
}
type RotationOptions struct {
MaxSizeMB int
MaxBackups int
MaxAgeDays int
Compress bool
LocalTime bool
}
type SamplingOptions struct {
Enabled bool
Initial int
Thereafter int
}
func (o InitOptions) normalized() InitOptions {
out := o
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
if out.Level == "" {
out.Level = "info"
}
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
if out.Format == "" {
out.Format = "console"
}
out.ServiceName = strings.TrimSpace(out.ServiceName)
if out.ServiceName == "" {
out.ServiceName = "sub2api"
}
out.Environment = strings.TrimSpace(out.Environment)
if out.Environment == "" {
out.Environment = "production"
}
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
if out.StacktraceLevel == "" {
out.StacktraceLevel = "error"
}
if !out.Output.ToStdout && !out.Output.ToFile {
out.Output.ToStdout = true
}
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
if out.Rotation.MaxSizeMB <= 0 {
out.Rotation.MaxSizeMB = 100
}
if out.Rotation.MaxBackups < 0 {
out.Rotation.MaxBackups = 10
}
if out.Rotation.MaxAgeDays < 0 {
out.Rotation.MaxAgeDays = 7
}
if out.Sampling.Enabled {
if out.Sampling.Initial <= 0 {
out.Sampling.Initial = 100
}
if out.Sampling.Thereafter <= 0 {
out.Sampling.Thereafter = 100
}
}
return out
}
func resolveLogFilePath(explicit string) string {
explicit = strings.TrimSpace(explicit)
if explicit != "" {
return explicit
}
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
if dataDir != "" {
return filepath.Join(dataDir, "logs", defaultLogFilename)
}
return DefaultContainerLogPath
}
func bootstrapOptions() InitOptions {
return InitOptions{
Level: "info",
Format: "console",
ServiceName: "sub2api",
Environment: "bootstrap",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 100,
MaxBackups: 10,
MaxAgeDays: 7,
Compress: true,
LocalTime: true,
},
Sampling: SamplingOptions{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
}
}
func parseLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "debug":
return LevelDebug, true
case "info":
return LevelInfo, true
case "warn":
return LevelWarn, true
case "error":
return LevelError, true
default:
return LevelInfo, false
}
}
func parseStacktraceLevel(level string) (Level, bool) {
switch strings.ToLower(strings.TrimSpace(level)) {
case "none":
return LevelFatal + 1, true
case "error":
return LevelError, true
case "fatal":
return LevelFatal, true
default:
return LevelError, false
}
}
func samplingTick() time.Duration {
return time.Second
}

View File

@@ -0,0 +1,102 @@
package logger
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func TestResolveLogFilePath_Default(t *testing.T) {
t.Setenv("DATA_DIR", "")
got := resolveLogFilePath("")
if got != DefaultContainerLogPath {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
}
}
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
got := resolveLogFilePath("")
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
if got != want {
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
}
}
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
t.Setenv("DATA_DIR", "/tmp/ignore")
got := resolveLogFilePath("/var/log/custom.log")
if got != "/var/log/custom.log" {
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
}
}
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := InitOptions{
Level: "TRACE",
Format: "TEXT",
ServiceName: "",
Environment: "",
StacktraceLevel: "panic",
Output: OutputOptions{
ToStdout: false,
ToFile: false,
},
Rotation: RotationOptions{
MaxSizeMB: 0,
MaxBackups: -1,
MaxAgeDays: -1,
},
Sampling: SamplingOptions{
Enabled: true,
Initial: 0,
Thereafter: 0,
},
}
out := opts.normalized()
if out.Level != "trace" {
// normalized 仅做 trim/lower不做校验校验在 config 层。
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
}
if !out.Output.ToStdout {
t.Fatalf("normalized output should fallback to stdout")
}
if out.Output.FilePath != DefaultContainerLogPath {
t.Fatalf("normalized file path = %q", out.Output.FilePath)
}
if out.Rotation.MaxSizeMB != 100 {
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
}
if out.Rotation.MaxBackups != 10 {
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
}
if out.Rotation.MaxAgeDays != 7 {
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
}
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
}
}
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
t.Setenv("DATA_DIR", "")
opts := bootstrapOptions()
opts.Output.ToFile = true
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
encoderCfg := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
MessageKey: "msg",
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeLevel: zapcore.CapitalLevelEncoder,
}
encoder := zapcore.NewJSONEncoder(encoderCfg)
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
if err == nil {
t.Fatalf("buildFileCore() expected error for invalid path")
}
}

View File

@@ -0,0 +1,132 @@
package logger
import (
"context"
"log/slog"
"strings"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type slogZapHandler struct {
logger *zap.Logger
attrs []slog.Attr
groups []string
}
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
if logger == nil {
logger = zap.NewNop()
}
return &slogZapHandler{
logger: logger,
attrs: make([]slog.Attr, 0, 8),
groups: make([]string, 0, 4),
}
}
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
switch {
case level >= slog.LevelError:
return h.logger.Core().Enabled(LevelError)
case level >= slog.LevelWarn:
return h.logger.Core().Enabled(LevelWarn)
case level <= slog.LevelDebug:
return h.logger.Core().Enabled(LevelDebug)
default:
return h.logger.Core().Enabled(LevelInfo)
}
}
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
record.Attrs(func(attr slog.Attr) bool {
fields = append(fields, slogAttrToZapField(h.groups, attr))
return true
})
entry := h.logger.With(fields...)
switch {
case record.Level >= slog.LevelError:
entry.Error(record.Message)
case record.Level >= slog.LevelWarn:
entry.Warn(record.Message)
case record.Level <= slog.LevelDebug:
entry.Debug(record.Message)
default:
entry.Info(record.Message)
}
return nil
}
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
next := *h
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
return &next
}
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
name = strings.TrimSpace(name)
if name == "" {
return h
}
next := *h
next.groups = append(append([]string{}, h.groups...), name)
return &next
}
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
fields := make([]zap.Field, 0, len(attrs))
for _, attr := range attrs {
fields = append(fields, slogAttrToZapField(groups, attr))
}
return fields
}
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
if len(groups) > 0 {
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
}
value := attr.Value.Resolve()
switch value.Kind() {
case slog.KindBool:
return zap.Bool(attr.Key, value.Bool())
case slog.KindInt64:
return zap.Int64(attr.Key, value.Int64())
case slog.KindUint64:
return zap.Uint64(attr.Key, value.Uint64())
case slog.KindFloat64:
return zap.Float64(attr.Key, value.Float64())
case slog.KindDuration:
return zap.Duration(attr.Key, value.Duration())
case slog.KindTime:
return zap.Time(attr.Key, value.Time())
case slog.KindString:
return zap.String(attr.Key, value.String())
case slog.KindGroup:
groupFields := make([]zap.Field, 0, len(value.Group()))
for _, nested := range value.Group() {
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
}
return zap.Object(attr.Key, zapObjectFields(groupFields))
case slog.KindAny:
if t, ok := value.Any().(time.Time); ok {
return zap.Time(attr.Key, t)
}
return zap.Any(attr.Key, value.Any())
default:
return zap.String(attr.Key, value.String())
}
}
type zapObjectFields []zap.Field
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
for _, field := range z {
field.AddTo(enc)
}
return nil
}

View File

@@ -0,0 +1,88 @@
package logger
import (
"context"
"log/slog"
"testing"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type captureState struct {
writes []capturedWrite
}
type capturedWrite struct {
fields []zapcore.Field
}
type captureCore struct {
state *captureState
withFields []zapcore.Field
}
func newCaptureCore() *captureCore {
return &captureCore{state: &captureState{}}
}
func (c *captureCore) Enabled(zapcore.Level) bool {
return true
}
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
nextFields = append(nextFields, c.withFields...)
nextFields = append(nextFields, fields...)
return &captureCore{
state: c.state,
withFields: nextFields,
}
}
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
return ce.AddCore(entry, c)
}
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
allFields = append(allFields, c.withFields...)
allFields = append(allFields, fields...)
c.state.writes = append(c.state.writes, capturedWrite{
fields: allFields,
})
return nil
}
func (c *captureCore) Sync() error {
return nil
}
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
core := newCaptureCore()
handler := newSlogZapHandler(zap.New(core))
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
record.AddAttrs(slog.String("component", "http.access"))
if err := handler.Handle(context.Background(), record); err != nil {
t.Fatalf("handle slog record: %v", err)
}
if len(core.state.writes) != 1 {
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
}
var hasComponent bool
for _, field := range core.state.writes[0].fields {
if field.Key == "time" {
t.Fatalf("unexpected duplicate time field in slog adapter output")
}
if field.Key == "component" {
hasComponent = true
}
}
if !hasComponent {
t.Fatalf("component field should be preserved")
}
}

View File

@@ -0,0 +1,165 @@
package logger
import (
"io"
"log"
"os"
"strings"
"testing"
)
func TestInferStdLogLevel(t *testing.T) {
cases := []struct {
msg string
want Level
}{
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}
for _, tc := range cases {
got := inferStdLogLevel(tc.msg)
if got != tc.want {
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
}
}
}
func TestNormalizeStdLogMessage(t *testing.T) {
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
got := normalizeStdLogMessage(raw)
want := "[TokenRefresh] cycle complete total=1 failed=0"
if got != want {
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
}
}
func TestStdLogBridgeRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
log.Printf("service started")
log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "service started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "Forward request failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
}
}
func TestLegacyPrintfRoutesLevels(t *testing.T) {
origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("create stdout pipe: %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("create stderr pipe: %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
t.Cleanup(func() {
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutR.Close()
_ = stdoutW.Close()
_ = stderrR.Close()
_ = stderrW.Close()
})
if err := Init(InitOptions{
Level: "debug",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: OutputOptions{
ToStdout: true,
ToFile: false,
},
Sampling: SamplingOptions{Enabled: false},
}); err != nil {
t.Fatalf("Init() error: %v", err)
}
LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout")
Sync()
_ = stdoutW.Close()
_ = stderrW.Close()
stdoutBytes, _ := io.ReadAll(stdoutR)
stderrBytes, _ := io.ReadAll(stderrR)
stdoutText := string(stdoutBytes)
stderrText := string(stderrBytes)
if !strings.Contains(stdoutText, "request started") {
t.Fatalf("stdout missing info log: %s", stdoutText)
}
if !strings.Contains(stderrText, "Warning: queue full") {
t.Fatalf("stderr missing warn log: %s", stderrText)
}
if !strings.Contains(stderrText, "forward failed: timeout") {
t.Fatalf("stderr missing error log: %s", stderrText)
}
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
}
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
t.Fatalf("stderr missing component field: %s", stderrText)
}
}

View File

@@ -50,6 +50,7 @@ type OAuthSession struct {
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{}
}
@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
s.stopOnce.Do(func() {
close(s.stopCh)
})
}
// Set stores a session

View File

@@ -0,0 +1,43 @@
package oauth
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}

View File

@@ -15,8 +15,8 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},

View File

@@ -17,6 +17,8 @@ import (
const (
// OAuth Client ID for OpenAI (Codex CLI official)
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
// OAuth endpoints
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
@@ -47,6 +49,7 @@ type OAuthSession struct {
type SessionStore struct {
mu sync.RWMutex
sessions map[string]*OAuthSession
stopOnce sync.Once
stopCh chan struct{}
}
@@ -92,7 +95,9 @@ func (s *SessionStore) Delete(sessionID string) {
// Stop stops the cleanup goroutine
func (s *SessionStore) Stop() {
close(s.stopCh)
s.stopOnce.Do(func() {
close(s.stopCh)
})
}
// cleanup removes expired sessions periodically

View File

@@ -0,0 +1,43 @@
package openai
import (
"sync"
"testing"
"time"
)
func TestSessionStore_Stop_Idempotent(t *testing.T) {
store := NewSessionStore()
store.Stop()
store.Stop()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}
func TestSessionStore_Stop_Concurrent(t *testing.T) {
store := NewSessionStore()
var wg sync.WaitGroup
for range 50 {
wg.Add(1)
go func() {
defer wg.Done()
store.Stop()
}()
}
wg.Wait()
select {
case <-store.stopCh:
// ok
case <-time.After(time.Second):
t.Fatal("stopCh 未关闭")
}
}

View File

@@ -1,5 +1,7 @@
package openai
import "strings"
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
var CodexCLIUserAgentPrefixes = []string{
@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{
"codex_cli_rs/",
}
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
var CodexOfficialClientUserAgentPrefixes = []string{
"codex_cli_rs/",
"codex_vscode/",
"codex_app/",
"codex_chatgpt_desktop/",
"codex_atlas/",
"codex_exec/",
"codex_sdk_ts/",
"codex ",
}
// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。
// 说明OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。
// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。
var CodexOfficialClientOriginatorPrefixes = []string{
"codex_",
"codex ",
}
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
func IsCodexCLIRequest(userAgent string) bool {
for _, prefix := range CodexCLIUserAgentPrefixes {
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
ua := normalizeCodexClientHeader(userAgent)
if ua == "" {
return false
}
return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes)
}
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
func IsCodexOfficialClientRequest(userAgent string) bool {
ua := normalizeCodexClientHeader(userAgent)
if ua == "" {
return false
}
return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes)
}
// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。
func IsCodexOfficialClientOriginator(originator string) bool {
v := normalizeCodexClientHeader(originator)
if v == "" {
return false
}
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
}
func normalizeCodexClientHeader(value string) string {
return strings.ToLower(strings.TrimSpace(value))
}
func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool {
for _, prefix := range prefixes {
normalizedPrefix := normalizeCodexClientHeader(prefix)
if normalizedPrefix == "" {
continue
}
// 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。
if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) {
return true
}
}

View File

@@ -0,0 +1,87 @@
package openai
import "testing"
func TestIsCodexCLIRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexCLIRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}
func TestIsCodexOfficialClientRequest(t *testing.T) {
tests := []struct {
name string
ua string
want bool
}{
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true},
{name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true},
{name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true},
{name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true},
{name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true},
{name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true},
{name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true},
{name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true},
{name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true},
{name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true},
{name: "非 codex", ua: "curl/8.0.1", want: false},
{name: "空字符串", ua: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientRequest(tt.ua)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want)
}
})
}
}
func TestIsCodexOfficialClientOriginator(t *testing.T) {
tests := []struct {
name string
originator string
want bool
}{
{name: "codex_cli_rs", originator: "codex_cli_rs", want: true},
{name: "codex_vscode", originator: "codex_vscode", want: true},
{name: "codex_app", originator: "codex_app", want: true},
{name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true},
{name: "codex_atlas", originator: "codex_atlas", want: true},
{name: "codex_exec", originator: "codex_exec", want: true},
{name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true},
{name: "Codex 前缀", originator: "Codex Desktop", want: true},
{name: "空白包裹", originator: " codex_vscode ", want: true},
{name: "非 codex", originator: "my_client", want: false},
{name: "空字符串", originator: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsCodexOfficialClientOriginator(tt.originator)
if got != tt.want {
t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want)
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/gin-gonic/gin"
)
@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool {
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)

View File

@@ -14,6 +14,44 @@ import (
"github.com/stretchr/testify/require"
)
// ---------- 辅助函数 ----------
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
t.Helper()
var got Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
return got
}
// parsePaginatedBody 从响应体中解析分页数据Data 字段是 PaginatedData
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
var pd PaginatedData
require.NoError(t, json.Unmarshal(raw.Data, &pd))
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
}
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
return w, c
}
// ---------- 现有测试 ----------
func TestErrorWithDetails(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
})
}
}
// ---------- 新增测试 ----------
func TestSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
wantBody Response
}{
{
name: "返回字符串数据",
data: "hello",
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
},
{
name: "返回nil数据",
data: nil,
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
{
name: "返回map数据",
data: map[string]string{"key": "value"},
wantCode: http.StatusOK,
wantBody: Response{Code: 0, Message: "success"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Success(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
// 只验证 code 和 messagedata 字段类型在 JSON 反序列化时会变成 map/slice
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
if tt.data == nil {
require.Nil(t, got.Data)
} else {
require.NotNil(t, got.Data)
}
})
}
}
func TestCreated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
data any
wantCode int
}{
{
name: "创建成功_返回数据",
data: map[string]int{"id": 42},
wantCode: http.StatusCreated,
},
{
name: "创建成功_nil数据",
data: nil,
wantCode: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Created(c, tt.data)
require.Equal(t, tt.wantCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, 0, got.Code)
require.Equal(t, "success", got.Message)
})
}
}
func TestError(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
statusCode int
message string
}{
{
name: "400错误",
statusCode: http.StatusBadRequest,
message: "bad request",
},
{
name: "500错误",
statusCode: http.StatusInternalServerError,
message: "internal error",
},
{
name: "自定义状态码",
statusCode: 418,
message: "I'm a teapot",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Error(c, tt.statusCode, tt.message)
require.Equal(t, tt.statusCode, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, tt.statusCode, got.Code)
require.Equal(t, tt.message, got.Message)
require.Empty(t, got.Reason)
require.Nil(t, got.Metadata)
require.Nil(t, got.Data)
})
}
}
func TestBadRequest(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
BadRequest(c, "参数无效")
require.Equal(t, http.StatusBadRequest, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusBadRequest, got.Code)
require.Equal(t, "参数无效", got.Message)
}
func TestUnauthorized(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Unauthorized(c, "未登录")
require.Equal(t, http.StatusUnauthorized, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusUnauthorized, got.Code)
require.Equal(t, "未登录", got.Message)
}
func TestForbidden(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Forbidden(c, "无权限")
require.Equal(t, http.StatusForbidden, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusForbidden, got.Code)
require.Equal(t, "无权限", got.Message)
}
func TestNotFound(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
NotFound(c, "资源不存在")
require.Equal(t, http.StatusNotFound, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusNotFound, got.Code)
require.Equal(t, "资源不存在", got.Message)
}
func TestInternalError(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
InternalError(c, "服务器内部错误")
require.Equal(t, http.StatusInternalServerError, w.Code)
got := parseResponseBody(t, w)
require.Equal(t, http.StatusInternalServerError, got.Code)
require.Equal(t, "服务器内部错误", got.Message)
}
func TestPaginated(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
total int64
page int
pageSize int
wantPages int
wantTotal int64
wantPage int
wantPageSize int
}{
{
name: "标准分页_多页",
items: []string{"a", "b"},
total: 25,
page: 1,
pageSize: 10,
wantPages: 3,
wantTotal: 25,
wantPage: 1,
wantPageSize: 10,
},
{
name: "总数刚好整除",
items: []string{"a"},
total: 20,
page: 2,
pageSize: 10,
wantPages: 2,
wantTotal: 20,
wantPage: 2,
wantPageSize: 10,
},
{
name: "总数为0_pages至少为1",
items: []string{},
total: 0,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 0,
wantPage: 1,
wantPageSize: 10,
},
{
name: "单页数据",
items: []int{1, 2, 3},
total: 3,
page: 1,
pageSize: 20,
wantPages: 1,
wantTotal: 3,
wantPage: 1,
wantPageSize: 20,
},
{
name: "总数为1",
items: []string{"only"},
total: 1,
page: 1,
pageSize: 10,
wantPages: 1,
wantTotal: 1,
wantPage: 1,
wantPageSize: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestPaginatedWithResult(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
items any
pagination *PaginationResult
wantTotal int64
wantPage int
wantPageSize int
wantPages int
}{
{
name: "正常分页结果",
items: []string{"a", "b"},
pagination: &PaginationResult{
Total: 50,
Page: 3,
PageSize: 10,
Pages: 5,
},
wantTotal: 50,
wantPage: 3,
wantPageSize: 10,
wantPages: 5,
},
{
name: "pagination为nil_使用默认值",
items: []string{},
pagination: nil,
wantTotal: 0,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
{
name: "单页结果",
items: []int{1},
pagination: &PaginationResult{
Total: 1,
Page: 1,
PageSize: 20,
Pages: 1,
},
wantTotal: 1,
wantPage: 1,
wantPageSize: 20,
wantPages: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
PaginatedWithResult(c, tt.items, tt.pagination)
require.Equal(t, http.StatusOK, w.Code)
resp, pd := parsePaginatedBody(t, w)
require.Equal(t, 0, resp.Code)
require.Equal(t, "success", resp.Message)
require.Equal(t, tt.wantTotal, pd.Total)
require.Equal(t, tt.wantPage, pd.Page)
require.Equal(t, tt.wantPageSize, pd.PageSize)
require.Equal(t, tt.wantPages, pd.Pages)
})
}
}
func TestParsePagination(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
query string
wantPage int
wantPageSize int
}{
{
name: "无参数_使用默认值",
query: "",
wantPage: 1,
wantPageSize: 20,
},
{
name: "仅指定page",
query: "page=3",
wantPage: 3,
wantPageSize: 20,
},
{
name: "仅指定page_size",
query: "page_size=50",
wantPage: 1,
wantPageSize: 50,
},
{
name: "同时指定page和page_size",
query: "page=2&page_size=30",
wantPage: 2,
wantPageSize: 30,
},
{
name: "使用limit代替page_size",
query: "limit=15",
wantPage: 1,
wantPageSize: 15,
},
{
name: "page_size优先于limit",
query: "page_size=25&limit=50",
wantPage: 1,
wantPageSize: 25,
},
{
name: "page为0_使用默认值",
query: "page=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size超过1000_使用默认值",
query: "page_size=1001",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size恰好1000_有效",
query: "page_size=1000",
wantPage: 1,
wantPageSize: 1000,
},
{
name: "page为非数字_使用默认值",
query: "page=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为非数字_使用默认值",
query: "page_size=xyz",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为非数字_使用默认值",
query: "limit=abc",
wantPage: 1,
wantPageSize: 20,
},
{
name: "page_size为0_使用默认值",
query: "page_size=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit为0_使用默认值",
query: "limit=0",
wantPage: 1,
wantPageSize: 20,
},
{
name: "大页码",
query: "page=999&page_size=100",
wantPage: 999,
wantPageSize: 100,
},
{
name: "page_size为1_最小有效值",
query: "page_size=1",
wantPage: 1,
wantPageSize: 1,
},
{
name: "混合数字和字母的page",
query: "page=12a",
wantPage: 1,
wantPageSize: 20,
},
{
name: "limit超过1000_使用默认值",
query: "limit=2000",
wantPage: 1,
wantPageSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, c := newContextWithQuery(tt.query)
page, pageSize := ParsePagination(c)
require.Equal(t, tt.wantPage, page, "page 不符合预期")
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
})
}
}
func Test_parseInt(t *testing.T) {
tests := []struct {
name string
input string
wantVal int
wantErr bool
}{
{
name: "正常数字",
input: "123",
wantVal: 123,
wantErr: false,
},
{
name: "零",
input: "0",
wantVal: 0,
wantErr: false,
},
{
name: "单个数字",
input: "5",
wantVal: 5,
wantErr: false,
},
{
name: "大数字",
input: "99999",
wantVal: 99999,
wantErr: false,
},
{
name: "包含字母_返回0",
input: "abc",
wantVal: 0,
wantErr: false,
},
{
name: "数字开头接字母_返回0",
input: "12a",
wantVal: 0,
wantErr: false,
},
{
name: "包含负号_返回0",
input: "-1",
wantVal: 0,
wantErr: false,
},
{
name: "包含小数点_返回0",
input: "1.5",
wantVal: 0,
wantErr: false,
},
{
name: "包含空格_返回0",
input: "1 2",
wantVal: 0,
wantErr: false,
},
{
name: "空字符串",
input: "",
wantVal: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := parseInt(tt.input)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantVal, val)
})
}
}

View File

@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
return nil, fmt.Errorf("apply TLS preset: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)

View File

@@ -1,3 +1,5 @@
//go:build unit
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
// Unit tests for TLS fingerprint dialer.
@@ -9,26 +11,161 @@
package tlsfingerprint
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
func TestDialerBasicConnection(t *testing.T) {
skipNetworkTest(t)
// Create a dialer with default profile
profile := &Profile{
Name: "Test Profile",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
// Create HTTP client with custom TLS dialer
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Make a request to a known HTTPS endpoint
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
// This test uses tls.peet.ws to verify the fingerprint.
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
func TestJA3Fingerprint(t *testing.T) {
skipNetworkTest(t)
profile := &Profile{
Name: "Claude CLI Test",
EnableGREASE: false,
}
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
// Use tls.peet.ws fingerprint detection API
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
}
// Log all fingerprint information
t.Logf("JA3: %s", fpResp.TLS.JA3)
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
t.Logf("JA4: %s", fpResp.TLS.JA4)
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
// Verify JA3 hash matches expected value
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
if fpResp.TLS.JA3Hash == expectedJA3Hash {
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
} else {
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
}
// Verify JA4 fingerprint
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
// The suffix _a33745022dd6_1f22a2ca17c4 should match
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
} else {
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
}
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
// d = domain (SNI present), i = IP (no SNI)
// Since we connect to tls.peet.ws (domain), we expect 'd'
expectedJA4Prefix := "t13d5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
} else {
// Also accept 'i' variant for IP connections
altPrefix := "t13i5911h1"
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
} else {
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
}
}
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
} else {
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
}
// Verify extension list (should be 11 extensions including SNI)
// Expected: 0-11-10-35-16-22-23-13-43-45-51
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
} else {
t.Logf("Warning: JA3 extension list may differ")
}
}
func skipNetworkTest(t *testing.T) {
if testing.Short() {
t.Skip("跳过网络测试short 模式)")
}
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1")
}
}
// TestDialerWithProfile tests that different profiles produce different fingerprints.
@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
// TestProfileExpectation defines expected fingerprint values for a profile.
type TestProfileExpectation struct {
Profile *Profile
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
ExpectedJA4 string // Expected full JA4 (empty = don't check)
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
}
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
func TestAllProfiles(t *testing.T) {
skipNetworkTest(t)
// Define all profiles to test with their expected fingerprints
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
profiles := []TestProfileExpectation{
{
// Linux x64 Node.js v22.17.1
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
Profile: &Profile{
Name: "linux_x64_node_v22171",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part
},
{
// MacOS arm64 Node.js v22.18.0
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
Profile: &Profile{
Name: "macos_arm64_node_v22180",
EnableGREASE: false,
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
PointFormats: []uint8{0, 1, 2},
},
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
},
}
for _, tc := range profiles {
tc := tc // capture range variable
t.Run(tc.Profile.Name, func(t *testing.T) {
fp := fetchFingerprint(t, tc.Profile)
if fp == nil {
return // fetchFingerprint already called t.Fatal
}
t.Logf("Profile: %s", tc.Profile.Name)
t.Logf(" JA3: %s", fp.JA3)
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
t.Logf(" JA4: %s", fp.JA4)
t.Logf(" PeetPrint: %s", fp.PeetPrint)
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
// Verify expectations
if tc.ExpectedJA3 != "" {
if fp.JA3Hash == tc.ExpectedJA3 {
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
} else {
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
}
}
if tc.ExpectedJA4 != "" {
if fp.JA4 == tc.ExpectedJA4 {
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
} else {
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
}
}
// Check JA4 cipher hash (stable middle part)
// JA4 format: prefix_cipherHash_extHash
if tc.JA4CipherHash != "" {
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
} else {
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
}
}
})
}
}
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
t.Helper()
dialer := NewDialer(profile, nil)
client := &http.Client{
Transport: &http.Transport{
DialTLSContext: dialer.DialTLSContext,
},
Timeout: 30 * time.Second,
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
return nil
}
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to get fingerprint: %v", err)
return nil
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response: %v", err)
return nil
}
var fpResp FingerprintResponse
if err := json.Unmarshal(body, &fpResp); err != nil {
t.Logf("Response body: %s", string(body))
t.Fatalf("failed to parse fingerprint response: %v", err)
return nil
}
return &fpResp.TLS
}

View File

@@ -0,0 +1,20 @@
package tlsfingerprint
// FingerprintResponse represents the response from tls.peet.ws/api/all.
// 共享测试类型,供 unit 和 integration 测试文件使用。
type FingerprintResponse struct {
IP string `json:"ip"`
TLS TLSInfo `json:"tls"`
HTTP2 any `json:"http2"`
}
// TLSInfo contains TLS fingerprint details.
type TLSInfo struct {
JA3 string `json:"ja3"`
JA3Hash string `json:"ja3_hash"`
JA4 string `json:"ja4"`
PeetPrint string `json:"peetprint"`
PeetPrintHash string `json:"peetprint_hash"`
ClientRandom string `json:"client_random"`
SessionID string `json:"session_id"`
}