Merge branch 'main' into feature/antigravity-user-agent-configurable
This commit is contained in:
@@ -204,9 +204,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool {
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("code", code)
|
||||
params.Set("redirect_uri", RedirectURI)
|
||||
params.Set("grant_type", "authorization_code")
|
||||
@@ -243,9 +248,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
|
||||
|
||||
// RefreshToken 刷新 access_token
|
||||
func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) {
|
||||
clientSecret, err := getClientSecret()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_secret", ClientSecret)
|
||||
params.Set("client_secret", clientSecret)
|
||||
params.Set("refresh_token", refreshToken)
|
||||
params.Set("grant_type", "refresh_token")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,11 +6,14 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,7 +24,11 @@ const (
|
||||
|
||||
// Antigravity OAuth 客户端凭证
|
||||
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
ClientSecret = "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"
|
||||
ClientSecret = ""
|
||||
|
||||
// AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。
|
||||
// 出于安全原因,该值不得硬编码入库。
|
||||
AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET"
|
||||
|
||||
// 固定的 redirect_uri(用户需手动复制 code)
|
||||
RedirectURI = "http://localhost:8085/callback"
|
||||
@@ -57,6 +64,17 @@ func init() {
|
||||
// GetUserAgent 返回当前配置的 User-Agent
|
||||
func GetUserAgent() string {
|
||||
return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion)
|
||||
|
||||
func getClientSecret() (string, error) {
|
||||
if v := strings.TrimSpace(ClientSecret); v != "" {
|
||||
return v, nil
|
||||
}
|
||||
if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok {
|
||||
if vv := strings.TrimSpace(v); vv != "" {
|
||||
return vv, nil
|
||||
}
|
||||
}
|
||||
return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv)
|
||||
}
|
||||
|
||||
// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致)
|
||||
|
||||
704
backend/internal/pkg/antigravity/oauth_test.go
Normal file
704
backend/internal/pkg/antigravity/oauth_test.go
Normal file
@@ -0,0 +1,704 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getClientSecret
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "my-secret-value" {
|
||||
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量为空时应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||
// t.Setenv 会在测试结束时恢复,但我们需要确保它不存在
|
||||
// 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值
|
||||
// 当前代码中 ClientSecret = "",所以会走环境变量逻辑
|
||||
|
||||
// 明确设置再取消,确保环境变量不存在
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量未设置时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " ")
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("环境变量仅含空格时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ")
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "valid-secret" {
|
||||
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ForwardBaseURLs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestForwardBaseURLs_Daily优先(t *testing.T) {
|
||||
urls := ForwardBaseURLs()
|
||||
if len(urls) == 0 {
|
||||
t.Fatal("ForwardBaseURLs 返回空列表")
|
||||
}
|
||||
|
||||
// daily URL 应排在第一位
|
||||
if urls[0] != antigravityDailyBaseURL {
|
||||
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
|
||||
}
|
||||
|
||||
// 应包含所有 URL
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
|
||||
// 验证 prod URL 也在列表中
|
||||
found := false
|
||||
for _, u := range urls {
|
||||
if u == antigravityProdBaseURL {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("ForwardBaseURLs 中缺少 prod URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
|
||||
originalFirst := BaseURLs[0]
|
||||
_ = ForwardBaseURLs()
|
||||
// 确保原始 BaseURLs 未被修改
|
||||
if BaseURLs[0] != originalFirst {
|
||||
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URLAvailability
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewURLAvailability(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if ua == nil {
|
||||
t.Fatal("NewURLAvailability 返回 nil")
|
||||
}
|
||||
if ua.ttl != 5*time.Minute {
|
||||
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
|
||||
}
|
||||
if ua.unavailable == nil {
|
||||
t.Error("unavailable map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkUnavailable(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后 IsAvailable 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkSuccess(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
// 先标记为不可用
|
||||
ua.MarkUnavailable(testURL)
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后应不可用")
|
||||
}
|
||||
|
||||
// 标记成功后应恢复可用
|
||||
ua.MarkSuccess(testURL)
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("MarkSuccess 后应恢复可用")
|
||||
}
|
||||
|
||||
// 验证 lastSuccess 被设置
|
||||
ua.mu.RLock()
|
||||
if ua.lastSuccess != testURL {
|
||||
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
|
||||
}
|
||||
ua.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
|
||||
// 使用极短的 TTL
|
||||
ua := NewURLAvailability(1 * time.Millisecond)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("TTL 过期后 URL 应恢复可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if !ua.IsAvailable("https://never-marked.com") {
|
||||
t.Error("未标记的 URL 应默认可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
// 默认所有 URL 都可用
|
||||
urls := ua.GetAvailableURLs()
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
if len(BaseURLs) < 2 {
|
||||
t.Skip("BaseURLs 少于 2 个,跳过此测试")
|
||||
}
|
||||
|
||||
ua.MarkUnavailable(BaseURLs[0])
|
||||
urls := ua.GetAvailableURLs()
|
||||
|
||||
// 标记的 URL 不应出现在可用列表中
|
||||
for _, u := range urls {
|
||||
if u == BaseURLs[0] {
|
||||
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
ua.MarkSuccess("https://c.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
// c.com 应排在第一位
|
||||
if urls[0] != "https://c.com" {
|
||||
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
|
||||
}
|
||||
// 其余按原始顺序
|
||||
if urls[1] != "https://a.com" {
|
||||
t.Errorf("第二个应为 a.com: got %s", urls[1])
|
||||
}
|
||||
if urls[2] != "https://b.com" {
|
||||
t.Errorf("第三个应为 b.com: got %s", urls[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://b.com")
|
||||
ua.MarkUnavailable("https://b.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// b.com 被标记不可用,不应出现
|
||||
if len(urls) != 1 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
|
||||
}
|
||||
if urls[0] != "https://a.com" {
|
||||
t.Errorf("仅 a.com 应可用: got %s", urls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://not-in-list.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// lastSuccess 不在自定义列表中,不应被添加
|
||||
if len(urls) != 2 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewSessionStore(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("NewSessionStore 返回 nil")
|
||||
}
|
||||
if store.sessions == nil {
|
||||
t.Error("sessions map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
CodeVerifier: "test-verifier",
|
||||
ProxyURL: "http://proxy.example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("session-1", session)
|
||||
|
||||
got, ok := store.Get("session-1")
|
||||
if !ok {
|
||||
t.Fatal("Get 应返回 true")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("State 不匹配: got %s", got.State)
|
||||
}
|
||||
if got.CodeVerifier != "test-verifier" {
|
||||
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
|
||||
}
|
||||
if got.ProxyURL != "http://proxy.example.com" {
|
||||
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("不存在的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_过期(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
|
||||
}
|
||||
|
||||
store.Set("expired-session", session)
|
||||
|
||||
_, ok := store.Get("expired-session")
|
||||
if ok {
|
||||
t.Error("过期的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("del-session", session)
|
||||
store.Delete("del-session")
|
||||
|
||||
_, ok := store.Get("del-session")
|
||||
if ok {
|
||||
t.Error("删除后 Get 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 删除不存在的 session 不应 panic
|
||||
store.Delete("nonexistent")
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Stop()
|
||||
|
||||
// 多次 Stop 不应 panic
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_多个Session(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
session := &OAuthSession{
|
||||
State: "state-" + string(rune('0'+i)),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("session-"+string(rune('0'+i)), session)
|
||||
}
|
||||
|
||||
// 验证都能取到
|
||||
for i := 0; i < 10; i++ {
|
||||
_, ok := store.Get("session-" + string(rune('0'+i)))
|
||||
if !ok {
|
||||
t.Errorf("session-%d 应存在", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes_长度正确(t *testing.T) {
|
||||
sizes := []int{0, 1, 16, 32, 64, 128}
|
||||
for _, size := range sizes {
|
||||
b, err := GenerateRandomBytes(size)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
|
||||
}
|
||||
if len(b) != size {
|
||||
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
|
||||
b1, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第一次调用失败: %v", err)
|
||||
}
|
||||
b2, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第二次调用失败: %v", err)
|
||||
}
|
||||
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
|
||||
if string(b1) == string(b2) {
|
||||
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState_返回值格式(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState 失败: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(state, "+/=") {
|
||||
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
|
||||
if len(state) != 43 {
|
||||
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState_唯一性(t *testing.T) {
|
||||
s1, _ := GenerateState()
|
||||
s2, _ := GenerateState()
|
||||
if s1 == s2 {
|
||||
t.Error("两次 GenerateState 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID_返回值格式(t *testing.T) {
|
||||
id, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID 失败: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Error("GenerateSessionID 返回空字符串")
|
||||
}
|
||||
// 16 字节的 hex 编码长度应为 32
|
||||
if len(id) != 32 {
|
||||
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
|
||||
}
|
||||
// 验证是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(id); err != nil {
|
||||
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_唯一性(t *testing.T) {
|
||||
id1, _ := GenerateSessionID()
|
||||
id2, _ := GenerateSessionID()
|
||||
if id1 == id2 {
|
||||
t.Error("两次 GenerateSessionID 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(verifier, "+/=") {
|
||||
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43
|
||||
if len(verifier) != 43 {
|
||||
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
|
||||
v1, _ := GenerateCodeVerifier()
|
||||
v2, _ := GenerateCodeVerifier()
|
||||
if v1 == v2 {
|
||||
t.Error("两次 GenerateCodeVerifier 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
|
||||
// 手动计算预期值
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
|
||||
|
||||
if challenge != expected {
|
||||
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("another-verifier")
|
||||
if strings.ContainsAny(challenge, "+/") {
|
||||
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("same-verifier")
|
||||
c2 := GenerateCodeChallenge("same-verifier")
|
||||
if c1 != c2 {
|
||||
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("verifier-1")
|
||||
c2 := GenerateCodeChallenge("verifier-2")
|
||||
if c1 == c2 {
|
||||
t.Error("不同输入应产生不同输出")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
codeChallenge := "test-challenge-abc"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
// 验证以 AuthorizeURL 开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
|
||||
}
|
||||
|
||||
// 解析 URL 并验证参数
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
for key, want := range expectedParams {
|
||||
got := params.Get(key)
|
||||
if got != want {
|
||||
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
|
||||
authURL := BuildAuthorizationURL("s", "c")
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
// 应包含 10 个参数
|
||||
expectedCount := 10
|
||||
if len(params) != expectedCount {
|
||||
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
|
||||
state := "state+with/special=chars"
|
||||
codeChallenge := "challenge+value"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析后应正确还原特殊字符
|
||||
if got := parsed.Query().Get("state"); got != state {
|
||||
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 常量值验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestConstants_值正确(t *testing.T) {
|
||||
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
|
||||
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
|
||||
}
|
||||
if TokenURL != "https://oauth2.googleapis.com/token" {
|
||||
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
|
||||
}
|
||||
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
|
||||
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
|
||||
}
|
||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||
}
|
||||
if ClientSecret != "" {
|
||||
t.Error("ClientSecret 应为空字符串")
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if UserAgent != "antigravity/1.15.8 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", UserAgent)
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
|
||||
}
|
||||
if URLAvailabilityTTL != 5*time.Minute {
|
||||
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopes_包含必要范围(t *testing.T) {
|
||||
expectedScopes := []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
for _, scope := range expectedScopes {
|
||||
if !strings.Contains(Scopes, scope) {
|
||||
t.Errorf("Scopes 缺少 %s", scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,13 @@
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式)
|
||||
@@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// generateRandomID 生成随机 ID
|
||||
// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。
|
||||
var fallbackCounter uint64
|
||||
|
||||
// generateRandomID 生成密码学安全的随机 ID
|
||||
func generateRandomID() string {
|
||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
result := make([]byte, 12)
|
||||
for i := range result {
|
||||
result[i] = chars[i%len(chars)]
|
||||
id := make([]byte, 12)
|
||||
randBytes := make([]byte, 12)
|
||||
if _, err := rand.Read(randBytes); err != nil {
|
||||
// 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。
|
||||
// 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。
|
||||
cnt := atomic.AddUint64(&fallbackCounter, 1)
|
||||
seed := uint64(time.Now().UnixNano()) ^ cnt
|
||||
seed ^= uint64(len(err.Error())) << 32
|
||||
for i := range id {
|
||||
seed ^= seed << 13
|
||||
seed ^= seed >> 7
|
||||
seed ^= seed << 17
|
||||
id[i] = chars[int(seed)%len(chars)]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
return string(result)
|
||||
for i, b := range randBytes {
|
||||
id[i] = chars[int(b)%len(chars)]
|
||||
}
|
||||
return string(id)
|
||||
}
|
||||
|
||||
109
backend/internal/pkg/antigravity/response_transformer_test.go
Normal file
109
backend/internal/pkg/antigravity/response_transformer_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 7: 验证 generateRandomID 和降级碰撞防护 ---
|
||||
|
||||
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
||||
seen := make(map[string]struct{}, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
require.Len(t, id, 12, "ID 长度应为 12")
|
||||
_, dup := seen[id]
|
||||
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFallbackCounter_Increments(t *testing.T) {
|
||||
// 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed
|
||||
before := atomic.LoadUint64(&fallbackCounter)
|
||||
cnt1 := atomic.AddUint64(&fallbackCounter, 1)
|
||||
cnt2 := atomic.AddUint64(&fallbackCounter, 1)
|
||||
require.Equal(t, before+1, cnt1, "第一次递增应为 before+1")
|
||||
require.Equal(t, before+2, cnt2, "第二次递增应为 before+2")
|
||||
require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同")
|
||||
}
|
||||
|
||||
func TestFallbackCounter_ConcurrentIncrements(t *testing.T) {
|
||||
// 验证并发递增的原子性 — 每次递增都应产生唯一值
|
||||
const goroutines = 50
|
||||
results := make([]uint64, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = atomic.AddUint64(&fallbackCounter, 1)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// 所有结果应唯一
|
||||
seen := make(map[uint64]bool, goroutines)
|
||||
for _, v := range results {
|
||||
assert.False(t, seen[v], "并发递增产生了重复值: %d", v)
|
||||
seen[v] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Charset(t *testing.T) {
|
||||
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
validSet := make(map[byte]struct{}, len(validChars))
|
||||
for i := 0; i < len(validChars); i++ {
|
||||
validSet[validChars[i]] = struct{}{}
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := generateRandomID()
|
||||
for j := 0; j < len(id); j++ {
|
||||
_, ok := validSet[id[j]]
|
||||
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_Length(t *testing.T) {
|
||||
for i := 0; i < 100; i++ {
|
||||
id := generateRandomID()
|
||||
assert.Len(t, id, 12, "每次生成的 ID 长度应为 12")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) {
|
||||
// 验证并发调用不会产生重复 ID
|
||||
const goroutines = 100
|
||||
results := make([]string, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = generateRandomID()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
seen := make(map[string]bool, goroutines)
|
||||
for _, id := range results {
|
||||
assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id)
|
||||
seen[id] = true
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGenerateRandomID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = generateRandomID()
|
||||
}
|
||||
}
|
||||
@@ -8,9 +8,21 @@ const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
|
||||
// RequestID 为服务端生成/透传的请求 ID。
|
||||
RequestID Key = "ctx_request_id"
|
||||
|
||||
// ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。
|
||||
ClientRequestID Key = "ctx_client_request_id"
|
||||
|
||||
// Model 请求模型标识(用于统一请求链路日志字段)。
|
||||
Model Key = "ctx_model"
|
||||
|
||||
// Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。
|
||||
Platform Key = "ctx_platform"
|
||||
|
||||
// AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。
|
||||
AccountID Key = "ctx_account_id"
|
||||
|
||||
// RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。
|
||||
RetryCount Key = "ctx_retry_count"
|
||||
|
||||
@@ -32,4 +44,12 @@ const (
|
||||
// SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。
|
||||
// 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。
|
||||
SingleAccountRetry Key = "ctx_single_account_retry"
|
||||
|
||||
// PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。
|
||||
// Service 层可复用该值,避免同请求链路重复读取 Redis。
|
||||
PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id"
|
||||
|
||||
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
|
||||
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
|
||||
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ func DefaultModels() []Model {
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,8 +38,13 @@ const (
|
||||
// GeminiCLIOAuthClientID/Secret are the public OAuth client credentials used by Google Gemini CLI.
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
// GeminiCLIOAuthClientSecret is intentionally not embedded in this repository.
|
||||
// If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env.
|
||||
GeminiCLIOAuthClientSecret = ""
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
SessionTTL = 30 * time.Minute
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ var DefaultModels = []Model{
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
}
|
||||
|
||||
// DefaultTestModel is the default model to preselect in test flows.
|
||||
|
||||
@@ -6,10 +6,14 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
type OAuthConfig struct {
|
||||
@@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error
|
||||
}
|
||||
|
||||
// Fall back to built-in Gemini CLI OAuth client when not configured.
|
||||
// SECURITY: This repo does not embed the built-in client secret; it must be provided via env.
|
||||
if effective.ClientID == "" && effective.ClientSecret == "" {
|
||||
secret := strings.TrimSpace(GeminiCLIOAuthClientSecret)
|
||||
if secret == "" {
|
||||
if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok {
|
||||
secret = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if secret == "" {
|
||||
return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv)
|
||||
}
|
||||
effective.ClientID = GeminiCLIOAuthClientID
|
||||
effective.ClientSecret = GeminiCLIOAuthClientSecret
|
||||
effective.ClientSecret = secret
|
||||
} else if effective.ClientID == "" || effective.ClientSecret == "" {
|
||||
return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)")
|
||||
}
|
||||
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID &&
|
||||
effective.ClientSecret == GeminiCLIOAuthClientSecret
|
||||
isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID
|
||||
|
||||
if effective.Scopes == "" {
|
||||
// Use different default scopes based on OAuth type
|
||||
|
||||
@@ -1,11 +1,439 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("sid-1", session)
|
||||
|
||||
got, ok := store.Get("sid-1")
|
||||
if !ok {
|
||||
t.Fatal("期望 Get 返回 ok=true,实际返回 false")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("期望 State=%q,实际=%q", "test-state", got.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_GetNotFound(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("不存在的ID")
|
||||
if ok {
|
||||
t.Error("期望不存在的 sessionID 返回 ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_GetExpired(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前)
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)),
|
||||
}
|
||||
store.Set("expired-sid", session)
|
||||
|
||||
_, ok := store.Get("expired-sid")
|
||||
if ok {
|
||||
t.Error("期望过期的 session 返回 ok=false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("del-sid", session)
|
||||
|
||||
// 先确认存在
|
||||
if _, ok := store.Get("del-sid"); !ok {
|
||||
t.Fatal("删除前 session 应该存在")
|
||||
}
|
||||
|
||||
store.Delete("del-sid")
|
||||
|
||||
if _, ok := store.Get("del-sid"); ok {
|
||||
t.Error("删除后 session 不应该存在")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
// 多次调用 Stop 不应 panic
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_ConcurrentAccess(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
// 并发写入
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Set(sid, &OAuthSession{
|
||||
State: sid,
|
||||
OAuthType: "code_assist",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 并发读取
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Get(sid) // 可能找到也可能没找到,关键是不 panic
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 并发删除
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
sid := "concurrent-" + string(rune('A'+idx%26))
|
||||
store.Delete(sid)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateRandomBytes(t *testing.T) {
|
||||
tests := []int{0, 1, 16, 32, 64}
|
||||
for _, n := range tests {
|
||||
b, err := GenerateRandomBytes(n)
|
||||
if err != nil {
|
||||
t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err)
|
||||
continue
|
||||
}
|
||||
if len(b) != n {
|
||||
t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_Uniqueness(t *testing.T) {
|
||||
// 两次调用应该返回不同的结果(极小概率相同,32字节足够)
|
||||
a, _ := GenerateRandomBytes(32)
|
||||
b, _ := GenerateRandomBytes(32)
|
||||
if string(a) == string(b) {
|
||||
t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState() 出错: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState() 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 padding '='
|
||||
if strings.Contains(state, "=") {
|
||||
t.Errorf("GenerateState() 结果包含 '=' padding: %s", state)
|
||||
}
|
||||
// base64url 不应包含 '+' 或 '/'
|
||||
if strings.ContainsAny(state, "+/") {
|
||||
t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID(t *testing.T) {
|
||||
sid, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID() 出错: %v", err)
|
||||
}
|
||||
// 16 字节 -> 32 个 hex 字符
|
||||
if len(sid) != 32 {
|
||||
t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid))
|
||||
}
|
||||
// 必须是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(sid); err != nil {
|
||||
t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_Uniqueness(t *testing.T) {
|
||||
a, _ := GenerateSessionID()
|
||||
b, _ := GenerateSessionID()
|
||||
if a == b {
|
||||
t.Error("两次 GenerateSessionID() 返回了相同结果")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier() 出错: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier() 返回空字符串")
|
||||
}
|
||||
// RFC 7636 要求 code_verifier 至少 43 个字符
|
||||
if len(verifier) < 43 {
|
||||
t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier))
|
||||
}
|
||||
// base64url 编码不应包含 padding 和非 URL 安全字符
|
||||
if strings.Contains(verifier, "=") {
|
||||
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier)
|
||||
}
|
||||
if strings.ContainsAny(verifier, "+/") {
|
||||
t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge(t *testing.T) {
|
||||
// 使用已知输入验证输出
|
||||
// RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
// 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
if challenge != expected {
|
||||
t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_NoPadding(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier-string")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// base64URLEncode 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBase64URLEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
}{
|
||||
{"空字节", []byte{}},
|
||||
{"单字节", []byte{0xff}},
|
||||
{"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}},
|
||||
{"全零", []byte{0x00, 0x00, 0x00}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := base64URLEncode(tt.input)
|
||||
// 不应包含 '=' padding
|
||||
if strings.Contains(result, "=") {
|
||||
t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result)
|
||||
}
|
||||
// 不应包含标准 base64 的 '+' 或 '/'
|
||||
if strings.ContainsAny(result, "+/") {
|
||||
t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// hasRestrictedScope 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHasRestrictedScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
scope string
|
||||
expected bool
|
||||
}{
|
||||
// 受限 scope
|
||||
{"https://www.googleapis.com/auth/generative-language", true},
|
||||
{"https://www.googleapis.com/auth/generative-language.retriever", true},
|
||||
{"https://www.googleapis.com/auth/generative-language.tuning", true},
|
||||
{"https://www.googleapis.com/auth/drive", true},
|
||||
{"https://www.googleapis.com/auth/drive.readonly", true},
|
||||
{"https://www.googleapis.com/auth/drive.file", true},
|
||||
// 非受限 scope
|
||||
{"https://www.googleapis.com/auth/cloud-platform", false},
|
||||
{"https://www.googleapis.com/auth/userinfo.email", false},
|
||||
{"https://www.googleapis.com/auth/userinfo.profile", false},
|
||||
// 边界情况
|
||||
{"", false},
|
||||
{"random-scope", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.scope, func(t *testing.T) {
|
||||
got := hasRestrictedScope(tt.scope)
|
||||
if got != tt.expected {
|
||||
t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
|
||||
}
|
||||
|
||||
// 检查返回的 URL 包含期望的参数
|
||||
checks := []string{
|
||||
"response_type=code",
|
||||
"client_id=" + GeminiCLIOAuthClientID,
|
||||
"redirect_uri=",
|
||||
"state=test-state",
|
||||
"code_challenge=test-challenge",
|
||||
"code_challenge_method=S256",
|
||||
"access_type=offline",
|
||||
"prompt=consent",
|
||||
"include_granted_scopes=true",
|
||||
}
|
||||
for _, check := range checks {
|
||||
if !strings.Contains(authURL, check) {
|
||||
t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 不应包含 project_id(因为传的是空字符串)
|
||||
if strings.Contains(authURL, "project_id=") {
|
||||
t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数")
|
||||
}
|
||||
|
||||
// URL 应该以正确的授权端点开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"", // 空 redirectURI
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "redirect_uri") {
|
||||
t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_WithProjectID(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret")
|
||||
|
||||
authURL, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"my-project-123",
|
||||
"code_assist",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizationURL() 出错: %v", err)
|
||||
}
|
||||
if !strings.Contains(authURL, "project_id=my-project-123") {
|
||||
t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) {
|
||||
// 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := BuildAuthorizationURL(
|
||||
OAuthConfig{},
|
||||
"test-state",
|
||||
"test-challenge",
|
||||
"https://example.com/callback",
|
||||
"",
|
||||
"code_assist",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EffectiveOAuthConfig 测试 - 原有测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
// 内置的 Gemini CLI client secret 不嵌入在此仓库中。
|
||||
// 测试通过环境变量设置一个假的 secret 来模拟运维配置。
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input OAuthConfig
|
||||
@@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Google One with built-in client (empty config)",
|
||||
name: "Google One 使用内置客户端(空配置)",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "google_one",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
@@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One always uses built-in client (even if custom credentials passed)",
|
||||
name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)",
|
||||
input: OAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client
|
||||
wantScopes: DefaultCodeAssistScopes,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and custom scopes (should filter restricted scopes)",
|
||||
name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
@@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Google One with built-in client and only restricted scopes (should fallback to default)",
|
||||
name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)",
|
||||
input: OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
},
|
||||
@@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Code Assist with built-in client",
|
||||
name: "Code Assist 使用内置客户端",
|
||||
input: OAuthConfig{},
|
||||
oauthType: "code_assist",
|
||||
wantClientID: GeminiCLIOAuthClientID,
|
||||
@@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
// Test that Google One with built-in client filters out restricted scopes
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 测试 Google One + 内置客户端过滤受限 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "google_one")
|
||||
@@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Should only contain cloud-platform, userinfo.email, and userinfo.profile
|
||||
// Should NOT contain generative-language or drive scopes
|
||||
// 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile
|
||||
// 不应包含 generative-language 或 drive scopes
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes)
|
||||
t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "drive") {
|
||||
t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes)
|
||||
t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.profile") {
|
||||
t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes)
|
||||
t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EffectiveOAuthConfig 测试 - 新增分支覆盖
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) {
|
||||
// 只提供 clientID 不提供 secret 应报错
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "some-client-id",
|
||||
}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("只提供 ClientID 不提供 ClientSecret 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
|
||||
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) {
|
||||
// 只提供 secret 不提供 clientID 应报错
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientSecret: "some-client-secret",
|
||||
}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("只提供 ClientSecret 不提供 ClientID 应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") {
|
||||
t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope)
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) {
|
||||
// ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultAIStudioScopes {
|
||||
t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) {
|
||||
// ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") {
|
||||
// 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever)
|
||||
parts := strings.Fields(cfg.Scopes)
|
||||
for _, p := range parts {
|
||||
if p == "https://www.googleapis.com/auth/generative-language" {
|
||||
t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
|
||||
t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 逗号分隔的 scopes 应被归一化为空格分隔
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 应该用空格分隔,而非逗号
|
||||
if strings.Contains(cfg.Scopes, ",") {
|
||||
t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "userinfo.email") {
|
||||
t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) {
|
||||
// 混合逗号和空格分隔的 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
parts := strings.Fields(cfg.Scopes)
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) {
|
||||
// 输入中的前后空白应被清理
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: " custom-id ",
|
||||
ClientSecret: " custom-secret ",
|
||||
Scopes: " https://www.googleapis.com/auth/cloud-platform ",
|
||||
}, "code_assist")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.ClientID != "custom-id" {
|
||||
t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID)
|
||||
}
|
||||
if cfg.ClientSecret != "custom-secret" {
|
||||
t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret)
|
||||
}
|
||||
if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" {
|
||||
t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) {
|
||||
// 不设置环境变量且不提供凭据,应该报错
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "")
|
||||
|
||||
_, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist")
|
||||
if err == nil {
|
||||
t.Error("没有内置 secret 且未提供凭据时应该报错")
|
||||
}
|
||||
if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) {
|
||||
t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever",
|
||||
}, "ai_studio")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 内置客户端应过滤 generative-language.retriever
|
||||
if strings.Contains(cfg.Scopes, "generative-language") {
|
||||
t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "cloud-platform") {
|
||||
t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 未知的 oauthType 应回退到默认的 code_assist scopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) {
|
||||
t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret")
|
||||
|
||||
// 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
if cfg.Scopes != DefaultCodeAssistScopes {
|
||||
t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) {
|
||||
// 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端)
|
||||
cfg, err := EffectiveOAuthConfig(OAuthConfig{
|
||||
ClientID: "custom-id",
|
||||
ClientSecret: "custom-secret",
|
||||
Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly",
|
||||
}, "google_one")
|
||||
if err != nil {
|
||||
t.Fatalf("EffectiveOAuthConfig() error = %v", err)
|
||||
}
|
||||
// 自定义客户端不应过滤任何 scope
|
||||
if !strings.Contains(cfg.Scopes, "generative-language.retriever") {
|
||||
t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes)
|
||||
}
|
||||
if !strings.Contains(cfg.Scopes, "drive.readonly") {
|
||||
t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string {
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。
|
||||
// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。
|
||||
// 适用于 ACL / 风控等安全敏感场景。
|
||||
func GetTrustedClientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return normalizeIP(c.ClientIP())
|
||||
}
|
||||
|
||||
// normalizeIP 规范化 IP 地址,去除端口号和空格。
|
||||
func normalizeIP(ip string) string {
|
||||
ip = strings.TrimSpace(ip)
|
||||
@@ -54,29 +64,34 @@ func normalizeIP(ip string) string {
|
||||
return ip
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||
var privateNets []*net.IPNet
|
||||
|
||||
// 私有 IP 范围
|
||||
privateBlocks := []string{
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
|
||||
for _, block := range privateBlocks {
|
||||
_, cidr, err := net.ParseCIDR(block)
|
||||
} {
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
panic("invalid CIDR: " + cidr)
|
||||
}
|
||||
if cidr.Contains(ip) {
|
||||
privateNets = append(privateNets, block)
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, block := range privateNets {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
75
backend/internal/pkg/ip/ip_test.go
Normal file
75
backend/internal/pkg/ip/ip_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
//go:build unit
|
||||
|
||||
package ip
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// 私有 IPv4
|
||||
{"10.x 私有地址", "10.0.0.1", true},
|
||||
{"10.x 私有地址段末", "10.255.255.255", true},
|
||||
{"172.16.x 私有地址", "172.16.0.1", true},
|
||||
{"172.31.x 私有地址", "172.31.255.255", true},
|
||||
{"192.168.x 私有地址", "192.168.1.1", true},
|
||||
{"127.0.0.1 本地回环", "127.0.0.1", true},
|
||||
{"127.x 回环段", "127.255.255.255", true},
|
||||
|
||||
// 公网 IPv4
|
||||
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
|
||||
{"1.1.1.1 公网", "1.1.1.1", false},
|
||||
{"172.15.255.255 非私有", "172.15.255.255", false},
|
||||
{"172.32.0.0 非私有", "172.32.0.0", false},
|
||||
{"11.0.0.1 公网", "11.0.0.1", false},
|
||||
|
||||
// IPv6
|
||||
{"::1 IPv6 回环", "::1", true},
|
||||
{"fc00:: IPv6 私有", "fc00::1", true},
|
||||
{"fd00:: IPv6 私有", "fd00::1", true},
|
||||
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
|
||||
|
||||
// 无效输入
|
||||
{"空字符串", "", false},
|
||||
{"非法字符串", "not-an-ip", false},
|
||||
{"不完整 IP", "192.168", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := isPrivateIP(tc.ip)
|
||||
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
require.NoError(t, r.SetTrustedProxies(nil))
|
||||
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
c.String(200, GetTrustedClientIP(c))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
31
backend/internal/pkg/logger/config_adapter.go
Normal file
31
backend/internal/pkg/logger/config_adapter.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package logger
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
func OptionsFromConfig(cfg config.LogConfig) InitOptions {
|
||||
return InitOptions{
|
||||
Level: cfg.Level,
|
||||
Format: cfg.Format,
|
||||
ServiceName: cfg.ServiceName,
|
||||
Environment: cfg.Environment,
|
||||
Caller: cfg.Caller,
|
||||
StacktraceLevel: cfg.StacktraceLevel,
|
||||
Output: OutputOptions{
|
||||
ToStdout: cfg.Output.ToStdout,
|
||||
ToFile: cfg.Output.ToFile,
|
||||
FilePath: cfg.Output.FilePath,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: cfg.Rotation.MaxSizeMB,
|
||||
MaxBackups: cfg.Rotation.MaxBackups,
|
||||
MaxAgeDays: cfg.Rotation.MaxAgeDays,
|
||||
Compress: cfg.Rotation.Compress,
|
||||
LocalTime: cfg.Rotation.LocalTime,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: cfg.Sampling.Enabled,
|
||||
Initial: cfg.Sampling.Initial,
|
||||
Thereafter: cfg.Sampling.Thereafter,
|
||||
},
|
||||
}
|
||||
}
|
||||
519
backend/internal/pkg/logger/logger.go
Normal file
519
backend/internal/pkg/logger/logger.go
Normal file
@@ -0,0 +1,519 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
type Level = zapcore.Level
|
||||
|
||||
const (
|
||||
LevelDebug = zapcore.DebugLevel
|
||||
LevelInfo = zapcore.InfoLevel
|
||||
LevelWarn = zapcore.WarnLevel
|
||||
LevelError = zapcore.ErrorLevel
|
||||
LevelFatal = zapcore.FatalLevel
|
||||
)
|
||||
|
||||
type Sink interface {
|
||||
WriteLogEvent(event *LogEvent)
|
||||
}
|
||||
|
||||
type LogEvent struct {
|
||||
Time time.Time
|
||||
Level string
|
||||
Component string
|
||||
Message string
|
||||
LoggerName string
|
||||
Fields map[string]any
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
global *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
atomicLevel zap.AtomicLevel
|
||||
initOptions InitOptions
|
||||
currentSink Sink
|
||||
stdLogUndo func()
|
||||
bootstrapOnce sync.Once
|
||||
)
|
||||
|
||||
func InitBootstrap() {
|
||||
bootstrapOnce.Do(func() {
|
||||
if err := Init(bootstrapOptions()); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Init(options InitOptions) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return initLocked(options)
|
||||
}
|
||||
|
||||
func initLocked(options InitOptions) error {
|
||||
normalized := options.normalized()
|
||||
zl, al, err := buildLogger(normalized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prev := global
|
||||
global = zl
|
||||
sugar = zl.Sugar()
|
||||
atomicLevel = al
|
||||
initOptions = normalized
|
||||
|
||||
bridgeSlogLocked()
|
||||
bridgeStdLogLocked()
|
||||
|
||||
if prev != nil {
|
||||
_ = prev.Sync()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Reconfigure(mutator func(*InitOptions) error) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
next := initOptions
|
||||
if mutator != nil {
|
||||
if err := mutator(&next); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return initLocked(next)
|
||||
}
|
||||
|
||||
func SetLevel(level string) error {
|
||||
lv, ok := parseLevel(level)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid log level: %s", level)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
atomicLevel.SetLevel(lv)
|
||||
initOptions.Level = strings.ToLower(strings.TrimSpace(level))
|
||||
return nil
|
||||
}
|
||||
|
||||
func CurrentLevel() string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global == nil {
|
||||
return "info"
|
||||
}
|
||||
return atomicLevel.Level().String()
|
||||
}
|
||||
|
||||
func SetSink(sink Sink) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentSink = sink
|
||||
}
|
||||
|
||||
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
|
||||
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
|
||||
func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
|
||||
level = strings.ToLower(strings.TrimSpace(level))
|
||||
if level == "" {
|
||||
level = "info"
|
||||
}
|
||||
component = strings.TrimSpace(component)
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return
|
||||
}
|
||||
|
||||
eventFields := make(map[string]any, len(fields)+1)
|
||||
for k, v := range fields {
|
||||
eventFields[k] = v
|
||||
}
|
||||
if component != "" {
|
||||
if _, ok := eventFields["component"]; !ok {
|
||||
eventFields["component"] = component
|
||||
}
|
||||
}
|
||||
|
||||
sink.WriteLogEvent(&LogEvent{
|
||||
Time: time.Now(),
|
||||
Level: level,
|
||||
Component: component,
|
||||
Message: message,
|
||||
LoggerName: component,
|
||||
Fields: eventFields,
|
||||
})
|
||||
}
|
||||
|
||||
func L() *zap.Logger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global != nil {
|
||||
return global
|
||||
}
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
func S() *zap.SugaredLogger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if sugar != nil {
|
||||
return sugar
|
||||
}
|
||||
return zap.NewNop().Sugar()
|
||||
}
|
||||
|
||||
func With(fields ...zap.Field) *zap.Logger {
|
||||
return L().With(fields...)
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
mu.RLock()
|
||||
l := global
|
||||
mu.RUnlock()
|
||||
if l != nil {
|
||||
_ = l.Sync()
|
||||
}
|
||||
}
|
||||
|
||||
func bridgeStdLogLocked() {
|
||||
if stdLogUndo != nil {
|
||||
stdLogUndo()
|
||||
stdLogUndo = nil
|
||||
}
|
||||
|
||||
prevFlags := log.Flags()
|
||||
prevPrefix := log.Prefix()
|
||||
prevWriter := log.Writer()
|
||||
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("")
|
||||
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
|
||||
|
||||
stdLogUndo = func() {
|
||||
log.SetOutput(prevWriter)
|
||||
log.SetFlags(prevFlags)
|
||||
log.SetPrefix(prevPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
func bridgeSlogLocked() {
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
|
||||
}
|
||||
|
||||
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
|
||||
level, _ := parseLevel(options.Level)
|
||||
atomic := zap.NewAtomicLevelAt(level)
|
||||
|
||||
encoderCfg := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.MillisDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
var enc zapcore.Encoder
|
||||
if options.Format == "console" {
|
||||
enc = zapcore.NewConsoleEncoder(encoderCfg)
|
||||
} else {
|
||||
enc = zapcore.NewJSONEncoder(encoderCfg)
|
||||
}
|
||||
|
||||
sinkCore := newSinkCore()
|
||||
cores := make([]zapcore.Core, 0, 3)
|
||||
|
||||
if options.Output.ToStdout {
|
||||
infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
|
||||
return lvl >= atomic.Level() && lvl < zapcore.WarnLevel
|
||||
})
|
||||
errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool {
|
||||
return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel
|
||||
})
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority))
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority))
|
||||
}
|
||||
|
||||
if options.Output.ToFile {
|
||||
fileCore, filePath, fileErr := buildFileCore(enc, atomic, options)
|
||||
if fileErr != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n",
|
||||
time.Now().Format(time.RFC3339Nano),
|
||||
filePath,
|
||||
fileErr,
|
||||
)
|
||||
} else {
|
||||
cores = append(cores, fileCore)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cores) == 0 {
|
||||
cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic))
|
||||
}
|
||||
|
||||
core := zapcore.NewTee(cores...)
|
||||
if options.Sampling.Enabled {
|
||||
core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter)
|
||||
}
|
||||
core = sinkCore.Wrap(core)
|
||||
|
||||
stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel)
|
||||
zapOpts := make([]zap.Option, 0, 5)
|
||||
if options.Caller {
|
||||
zapOpts = append(zapOpts, zap.AddCaller())
|
||||
}
|
||||
if stacktraceLevel <= zapcore.FatalLevel {
|
||||
zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel))
|
||||
}
|
||||
|
||||
logger := zap.New(core, zapOpts...).With(
|
||||
zap.String("service", options.ServiceName),
|
||||
zap.String("env", options.Environment),
|
||||
)
|
||||
return logger, atomic, nil
|
||||
}
|
||||
|
||||
func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) {
|
||||
filePath := options.Output.FilePath
|
||||
if strings.TrimSpace(filePath) == "" {
|
||||
filePath = resolveLogFilePath("")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(filePath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, filePath, err
|
||||
}
|
||||
lj := &lumberjack.Logger{
|
||||
Filename: filePath,
|
||||
MaxSize: options.Rotation.MaxSizeMB,
|
||||
MaxBackups: options.Rotation.MaxBackups,
|
||||
MaxAge: options.Rotation.MaxAgeDays,
|
||||
Compress: options.Rotation.Compress,
|
||||
LocalTime: options.Rotation.LocalTime,
|
||||
}
|
||||
return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil
|
||||
}
|
||||
|
||||
type sinkCore struct {
|
||||
core zapcore.Core
|
||||
fields []zapcore.Field
|
||||
}
|
||||
|
||||
func newSinkCore() *sinkCore {
|
||||
return &sinkCore{}
|
||||
}
|
||||
|
||||
func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core {
|
||||
cp := *s
|
||||
cp.core = core
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (s *sinkCore) Enabled(level zapcore.Level) bool {
|
||||
return s.core.Enabled(level)
|
||||
}
|
||||
|
||||
func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
nextFields := append([]zapcore.Field{}, s.fields...)
|
||||
nextFields = append(nextFields, fields...)
|
||||
return &sinkCore{
|
||||
core: s.core.With(fields),
|
||||
fields: nextFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
// Delegate to inner core (tee) so each sub-core's level enabler is respected.
|
||||
// Then add ourselves for sink forwarding only.
|
||||
ce = s.core.Check(entry, ce)
|
||||
if ce != nil {
|
||||
ce = ce.AddCore(entry, s)
|
||||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
// Only handle sink forwarding — the inner cores write via their own
|
||||
// Write methods (added to CheckedEntry by s.core.Check above).
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
for _, f := range s.fields {
|
||||
f.AddTo(enc)
|
||||
}
|
||||
for _, f := range fields {
|
||||
f.AddTo(enc)
|
||||
}
|
||||
|
||||
event := &LogEvent{
|
||||
Time: entry.Time,
|
||||
Level: strings.ToLower(entry.Level.String()),
|
||||
Component: entry.LoggerName,
|
||||
Message: entry.Message,
|
||||
LoggerName: entry.LoggerName,
|
||||
Fields: enc.Fields,
|
||||
}
|
||||
sink.WriteLogEvent(event)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sinkCore) Sync() error {
|
||||
return s.core.Sync()
|
||||
}
|
||||
|
||||
type stdLogBridge struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newStdLogBridge(l *zap.Logger) io.Writer {
|
||||
if l == nil {
|
||||
l = zap.NewNop()
|
||||
}
|
||||
return &stdLogBridge{logger: l}
|
||||
}
|
||||
|
||||
func (b *stdLogBridge) Write(p []byte) (int, error) {
|
||||
msg := normalizeStdLogMessage(string(p))
|
||||
if msg == "" {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
level := inferStdLogLevel(msg)
|
||||
entry := b.logger.WithOptions(zap.AddCallerSkip(4))
|
||||
|
||||
switch level {
|
||||
case LevelDebug:
|
||||
entry.Debug(msg, zap.Bool("legacy_stdlog", true))
|
||||
case LevelWarn:
|
||||
entry.Warn(msg, zap.Bool("legacy_stdlog", true))
|
||||
case LevelError, LevelFatal:
|
||||
entry.Error(msg, zap.Bool("legacy_stdlog", true))
|
||||
default:
|
||||
entry.Info(msg, zap.Bool("legacy_stdlog", true))
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func normalizeStdLogMessage(raw string) string {
|
||||
msg := strings.TrimSpace(strings.ReplaceAll(raw, "\n", " "))
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(msg), " ")
|
||||
}
|
||||
|
||||
func inferStdLogLevel(msg string) Level {
|
||||
lower := strings.ToLower(strings.TrimSpace(msg))
|
||||
if lower == "" {
|
||||
return LevelInfo
|
||||
}
|
||||
|
||||
if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") {
|
||||
return LevelDebug
|
||||
}
|
||||
if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") {
|
||||
return LevelWarn
|
||||
}
|
||||
if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") {
|
||||
return LevelError
|
||||
}
|
||||
|
||||
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
|
||||
return LevelError
|
||||
}
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
return LevelWarn
|
||||
}
|
||||
return LevelInfo
|
||||
}
|
||||
|
||||
// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。
|
||||
func LegacyPrintf(component, format string, args ...any) {
|
||||
msg := normalizeStdLogMessage(fmt.Sprintf(format, args...))
|
||||
if msg == "" {
|
||||
return
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
initialized := global != nil
|
||||
mu.RUnlock()
|
||||
if !initialized {
|
||||
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
|
||||
log.Print(msg)
|
||||
return
|
||||
}
|
||||
|
||||
l := L()
|
||||
if component != "" {
|
||||
l = l.With(zap.String("component", component))
|
||||
}
|
||||
l = l.WithOptions(zap.AddCallerSkip(1))
|
||||
|
||||
switch inferStdLogLevel(msg) {
|
||||
case LevelDebug:
|
||||
l.Debug(msg, zap.Bool("legacy_printf", true))
|
||||
case LevelWarn:
|
||||
l.Warn(msg, zap.Bool("legacy_printf", true))
|
||||
case LevelError, LevelFatal:
|
||||
l.Error(msg, zap.Bool("legacy_printf", true))
|
||||
default:
|
||||
l.Info(msg, zap.Bool("legacy_printf", true))
|
||||
}
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const loggerContextKey contextKey = "ctx_logger"
|
||||
|
||||
func IntoContext(ctx context.Context, l *zap.Logger) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if l == nil {
|
||||
l = L()
|
||||
}
|
||||
return context.WithValue(ctx, loggerContextKey, l)
|
||||
}
|
||||
|
||||
func FromContext(ctx context.Context) *zap.Logger {
|
||||
if ctx == nil {
|
||||
return L()
|
||||
}
|
||||
if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil {
|
||||
return l
|
||||
}
|
||||
return L()
|
||||
}
|
||||
192
backend/internal/pkg/logger/logger_test.go
Normal file
192
backend/internal/pkg/logger/logger_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInit_DualOutput(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
|
||||
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
err = Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: true,
|
||||
FilePath: logPath,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 2,
|
||||
MaxAgeDays: 1,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
L().Info("dual-output-info")
|
||||
L().Warn("dual-output-warn")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "dual-output-info") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "dual-output-warn") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
|
||||
fileBytes, err := os.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read log file: %v", err)
|
||||
}
|
||||
fileText := string(fileBytes)
|
||||
if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") {
|
||||
t.Fatalf("file missing logs: %s", fileText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_FileOutputFailureDowngrade(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
_, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
err = Init(InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: true,
|
||||
FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"),
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 10,
|
||||
MaxBackups: 1,
|
||||
MaxAgeDays: 1,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Init() should downgrade instead of failing, got: %v", err)
|
||||
}
|
||||
|
||||
_ = stderrW.Close()
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") {
|
||||
t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInit_CallerShouldPointToCallsite(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
_, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Caller: true,
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
L().Info("caller-check")
|
||||
Sync()
|
||||
_ = stdoutW.Close()
|
||||
logBytes, _ := io.ReadAll(stdoutR)
|
||||
|
||||
var line string
|
||||
for _, item := range strings.Split(string(logBytes), "\n") {
|
||||
if strings.Contains(item, "caller-check") {
|
||||
line = item
|
||||
break
|
||||
}
|
||||
}
|
||||
if line == "" {
|
||||
t.Fatalf("log output missing caller-check: %s", string(logBytes))
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &payload); err != nil {
|
||||
t.Fatalf("parse log json failed: %v, line=%s", err, line)
|
||||
}
|
||||
caller, _ := payload["caller"].(string)
|
||||
if !strings.Contains(caller, "logger_test.go:") {
|
||||
t.Fatalf("caller should point to this test file, got: %s", caller)
|
||||
}
|
||||
}
|
||||
161
backend/internal/pkg/logger/options.go
Normal file
161
backend/internal/pkg/logger/options.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultContainerLogPath 为容器内默认日志文件路径。
|
||||
DefaultContainerLogPath = "/app/data/logs/sub2api.log"
|
||||
defaultLogFilename = "sub2api.log"
|
||||
)
|
||||
|
||||
type InitOptions struct {
|
||||
Level string
|
||||
Format string
|
||||
ServiceName string
|
||||
Environment string
|
||||
Caller bool
|
||||
StacktraceLevel string
|
||||
Output OutputOptions
|
||||
Rotation RotationOptions
|
||||
Sampling SamplingOptions
|
||||
}
|
||||
|
||||
type OutputOptions struct {
|
||||
ToStdout bool
|
||||
ToFile bool
|
||||
FilePath string
|
||||
}
|
||||
|
||||
type RotationOptions struct {
|
||||
MaxSizeMB int
|
||||
MaxBackups int
|
||||
MaxAgeDays int
|
||||
Compress bool
|
||||
LocalTime bool
|
||||
}
|
||||
|
||||
type SamplingOptions struct {
|
||||
Enabled bool
|
||||
Initial int
|
||||
Thereafter int
|
||||
}
|
||||
|
||||
func (o InitOptions) normalized() InitOptions {
|
||||
out := o
|
||||
out.Level = strings.ToLower(strings.TrimSpace(out.Level))
|
||||
if out.Level == "" {
|
||||
out.Level = "info"
|
||||
}
|
||||
out.Format = strings.ToLower(strings.TrimSpace(out.Format))
|
||||
if out.Format == "" {
|
||||
out.Format = "console"
|
||||
}
|
||||
out.ServiceName = strings.TrimSpace(out.ServiceName)
|
||||
if out.ServiceName == "" {
|
||||
out.ServiceName = "sub2api"
|
||||
}
|
||||
out.Environment = strings.TrimSpace(out.Environment)
|
||||
if out.Environment == "" {
|
||||
out.Environment = "production"
|
||||
}
|
||||
out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel))
|
||||
if out.StacktraceLevel == "" {
|
||||
out.StacktraceLevel = "error"
|
||||
}
|
||||
if !out.Output.ToStdout && !out.Output.ToFile {
|
||||
out.Output.ToStdout = true
|
||||
}
|
||||
out.Output.FilePath = resolveLogFilePath(out.Output.FilePath)
|
||||
if out.Rotation.MaxSizeMB <= 0 {
|
||||
out.Rotation.MaxSizeMB = 100
|
||||
}
|
||||
if out.Rotation.MaxBackups < 0 {
|
||||
out.Rotation.MaxBackups = 10
|
||||
}
|
||||
if out.Rotation.MaxAgeDays < 0 {
|
||||
out.Rotation.MaxAgeDays = 7
|
||||
}
|
||||
if out.Sampling.Enabled {
|
||||
if out.Sampling.Initial <= 0 {
|
||||
out.Sampling.Initial = 100
|
||||
}
|
||||
if out.Sampling.Thereafter <= 0 {
|
||||
out.Sampling.Thereafter = 100
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func resolveLogFilePath(explicit string) string {
|
||||
explicit = strings.TrimSpace(explicit)
|
||||
if explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
dataDir := strings.TrimSpace(os.Getenv("DATA_DIR"))
|
||||
if dataDir != "" {
|
||||
return filepath.Join(dataDir, "logs", defaultLogFilename)
|
||||
}
|
||||
return DefaultContainerLogPath
|
||||
}
|
||||
|
||||
func bootstrapOptions() InitOptions {
|
||||
return InitOptions{
|
||||
Level: "info",
|
||||
Format: "console",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "bootstrap",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 100,
|
||||
MaxBackups: 10,
|
||||
MaxAgeDays: 7,
|
||||
Compress: true,
|
||||
LocalTime: true,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: false,
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func parseLevel(level string) (Level, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "debug":
|
||||
return LevelDebug, true
|
||||
case "info":
|
||||
return LevelInfo, true
|
||||
case "warn":
|
||||
return LevelWarn, true
|
||||
case "error":
|
||||
return LevelError, true
|
||||
default:
|
||||
return LevelInfo, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseStacktraceLevel(level string) (Level, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "none":
|
||||
return LevelFatal + 1, true
|
||||
case "error":
|
||||
return LevelError, true
|
||||
case "fatal":
|
||||
return LevelFatal, true
|
||||
default:
|
||||
return LevelError, false
|
||||
}
|
||||
}
|
||||
|
||||
func samplingTick() time.Duration {
|
||||
return time.Second
|
||||
}
|
||||
102
backend/internal/pkg/logger/options_test.go
Normal file
102
backend/internal/pkg/logger/options_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestResolveLogFilePath_Default(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
got := resolveLogFilePath("")
|
||||
if got != DefaultContainerLogPath {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogFilePath_WithDataDir(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "/tmp/sub2api-data")
|
||||
got := resolveLogFilePath("")
|
||||
want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log")
|
||||
if got != want {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLogFilePath_ExplicitPath(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "/tmp/ignore")
|
||||
got := resolveLogFilePath("/var/log/custom.log")
|
||||
if got != "/var/log/custom.log" {
|
||||
t.Fatalf("resolveLogFilePath() = %q, want explicit path", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedOptions_InvalidFallback(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
opts := InitOptions{
|
||||
Level: "TRACE",
|
||||
Format: "TEXT",
|
||||
ServiceName: "",
|
||||
Environment: "",
|
||||
StacktraceLevel: "panic",
|
||||
Output: OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
Rotation: RotationOptions{
|
||||
MaxSizeMB: 0,
|
||||
MaxBackups: -1,
|
||||
MaxAgeDays: -1,
|
||||
},
|
||||
Sampling: SamplingOptions{
|
||||
Enabled: true,
|
||||
Initial: 0,
|
||||
Thereafter: 0,
|
||||
},
|
||||
}
|
||||
out := opts.normalized()
|
||||
if out.Level != "trace" {
|
||||
// normalized 仅做 trim/lower,不做校验;校验在 config 层。
|
||||
t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level)
|
||||
}
|
||||
if !out.Output.ToStdout {
|
||||
t.Fatalf("normalized output should fallback to stdout")
|
||||
}
|
||||
if out.Output.FilePath != DefaultContainerLogPath {
|
||||
t.Fatalf("normalized file path = %q", out.Output.FilePath)
|
||||
}
|
||||
if out.Rotation.MaxSizeMB != 100 {
|
||||
t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB)
|
||||
}
|
||||
if out.Rotation.MaxBackups != 10 {
|
||||
t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups)
|
||||
}
|
||||
if out.Rotation.MaxAgeDays != 7 {
|
||||
t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays)
|
||||
}
|
||||
if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 {
|
||||
t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFileCore_InvalidPathFallback(t *testing.T) {
|
||||
t.Setenv("DATA_DIR", "")
|
||||
opts := bootstrapOptions()
|
||||
opts.Output.ToFile = true
|
||||
opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log")
|
||||
encoderCfg := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
MessageKey: "msg",
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
}
|
||||
encoder := zapcore.NewJSONEncoder(encoderCfg)
|
||||
_, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts)
|
||||
if err == nil {
|
||||
t.Fatalf("buildFileCore() expected error for invalid path")
|
||||
}
|
||||
}
|
||||
132
backend/internal/pkg/logger/slog_handler.go
Normal file
132
backend/internal/pkg/logger/slog_handler.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type slogZapHandler struct {
|
||||
logger *zap.Logger
|
||||
attrs []slog.Attr
|
||||
groups []string
|
||||
}
|
||||
|
||||
func newSlogZapHandler(logger *zap.Logger) slog.Handler {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
return &slogZapHandler{
|
||||
logger: logger,
|
||||
attrs: make([]slog.Attr, 0, 8),
|
||||
groups: make([]string, 0, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool {
|
||||
switch {
|
||||
case level >= slog.LevelError:
|
||||
return h.logger.Core().Enabled(LevelError)
|
||||
case level >= slog.LevelWarn:
|
||||
return h.logger.Core().Enabled(LevelWarn)
|
||||
case level <= slog.LevelDebug:
|
||||
return h.logger.Core().Enabled(LevelDebug)
|
||||
default:
|
||||
return h.logger.Core().Enabled(LevelInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3)
|
||||
fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...)
|
||||
record.Attrs(func(attr slog.Attr) bool {
|
||||
fields = append(fields, slogAttrToZapField(h.groups, attr))
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.With(fields...)
|
||||
switch {
|
||||
case record.Level >= slog.LevelError:
|
||||
entry.Error(record.Message)
|
||||
case record.Level >= slog.LevelWarn:
|
||||
entry.Warn(record.Message)
|
||||
case record.Level <= slog.LevelDebug:
|
||||
entry.Debug(record.Message)
|
||||
default:
|
||||
entry.Info(record.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
next := *h
|
||||
next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...)
|
||||
return &next
|
||||
}
|
||||
|
||||
func (h *slogZapHandler) WithGroup(name string) slog.Handler {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return h
|
||||
}
|
||||
next := *h
|
||||
next.groups = append(append([]string{}, h.groups...), name)
|
||||
return &next
|
||||
}
|
||||
|
||||
func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field {
|
||||
fields := make([]zap.Field, 0, len(attrs))
|
||||
for _, attr := range attrs {
|
||||
fields = append(fields, slogAttrToZapField(groups, attr))
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field {
|
||||
if len(groups) > 0 {
|
||||
attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".")
|
||||
}
|
||||
value := attr.Value.Resolve()
|
||||
switch value.Kind() {
|
||||
case slog.KindBool:
|
||||
return zap.Bool(attr.Key, value.Bool())
|
||||
case slog.KindInt64:
|
||||
return zap.Int64(attr.Key, value.Int64())
|
||||
case slog.KindUint64:
|
||||
return zap.Uint64(attr.Key, value.Uint64())
|
||||
case slog.KindFloat64:
|
||||
return zap.Float64(attr.Key, value.Float64())
|
||||
case slog.KindDuration:
|
||||
return zap.Duration(attr.Key, value.Duration())
|
||||
case slog.KindTime:
|
||||
return zap.Time(attr.Key, value.Time())
|
||||
case slog.KindString:
|
||||
return zap.String(attr.Key, value.String())
|
||||
case slog.KindGroup:
|
||||
groupFields := make([]zap.Field, 0, len(value.Group()))
|
||||
for _, nested := range value.Group() {
|
||||
groupFields = append(groupFields, slogAttrToZapField(nil, nested))
|
||||
}
|
||||
return zap.Object(attr.Key, zapObjectFields(groupFields))
|
||||
case slog.KindAny:
|
||||
if t, ok := value.Any().(time.Time); ok {
|
||||
return zap.Time(attr.Key, t)
|
||||
}
|
||||
return zap.Any(attr.Key, value.Any())
|
||||
default:
|
||||
return zap.String(attr.Key, value.String())
|
||||
}
|
||||
}
|
||||
|
||||
type zapObjectFields []zap.Field
|
||||
|
||||
func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error {
|
||||
for _, field := range z {
|
||||
field.AddTo(enc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
88
backend/internal/pkg/logger/slog_handler_test.go
Normal file
88
backend/internal/pkg/logger/slog_handler_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type captureState struct {
|
||||
writes []capturedWrite
|
||||
}
|
||||
|
||||
type capturedWrite struct {
|
||||
fields []zapcore.Field
|
||||
}
|
||||
|
||||
type captureCore struct {
|
||||
state *captureState
|
||||
withFields []zapcore.Field
|
||||
}
|
||||
|
||||
func newCaptureCore() *captureCore {
|
||||
return &captureCore{state: &captureState{}}
|
||||
}
|
||||
|
||||
func (c *captureCore) Enabled(zapcore.Level) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *captureCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
|
||||
nextFields = append(nextFields, c.withFields...)
|
||||
nextFields = append(nextFields, fields...)
|
||||
return &captureCore{
|
||||
state: c.state,
|
||||
withFields: nextFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
return ce.AddCore(entry, c)
|
||||
}
|
||||
|
||||
func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields))
|
||||
allFields = append(allFields, c.withFields...)
|
||||
allFields = append(allFields, fields...)
|
||||
c.state.writes = append(c.state.writes, capturedWrite{
|
||||
fields: allFields,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *captureCore) Sync() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) {
|
||||
core := newCaptureCore()
|
||||
handler := newSlogZapHandler(zap.New(core))
|
||||
|
||||
record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0)
|
||||
record.AddAttrs(slog.String("component", "http.access"))
|
||||
|
||||
if err := handler.Handle(context.Background(), record); err != nil {
|
||||
t.Fatalf("handle slog record: %v", err)
|
||||
}
|
||||
if len(core.state.writes) != 1 {
|
||||
t.Fatalf("write calls = %d, want 1", len(core.state.writes))
|
||||
}
|
||||
|
||||
var hasComponent bool
|
||||
for _, field := range core.state.writes[0].fields {
|
||||
if field.Key == "time" {
|
||||
t.Fatalf("unexpected duplicate time field in slog adapter output")
|
||||
}
|
||||
if field.Key == "component" {
|
||||
hasComponent = true
|
||||
}
|
||||
}
|
||||
if !hasComponent {
|
||||
t.Fatalf("component field should be preserved")
|
||||
}
|
||||
}
|
||||
165
backend/internal/pkg/logger/stdlog_bridge_test.go
Normal file
165
backend/internal/pkg/logger/stdlog_bridge_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInferStdLogLevel(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want Level
|
||||
}{
|
||||
{msg: "Warning: queue full", want: LevelWarn},
|
||||
{msg: "Forward request failed: timeout", want: LevelError},
|
||||
{msg: "[ERROR] upstream unavailable", want: LevelError},
|
||||
{msg: "service started", want: LevelInfo},
|
||||
{msg: "debug: cache miss", want: LevelDebug},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
got := inferStdLogLevel(tc.msg)
|
||||
if got != tc.want {
|
||||
t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStdLogMessage(t *testing.T) {
|
||||
raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n"
|
||||
got := normalizeStdLogMessage(raw)
|
||||
want := "[TokenRefresh] cycle complete total=1 failed=0"
|
||||
if got != want {
|
||||
t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStdLogBridgeRoutesLevels(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("service started")
|
||||
log.Printf("Warning: queue full")
|
||||
log.Printf("Forward request failed: timeout")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "service started") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Warning: queue full") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Forward request failed: timeout") {
|
||||
t.Fatalf("stderr missing error log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"legacy_stdlog\":true") {
|
||||
t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyPrintfRoutesLevels(t *testing.T) {
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
stdoutR, stdoutW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stdout pipe: %v", err)
|
||||
}
|
||||
stderrR, stderrW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("create stderr pipe: %v", err)
|
||||
}
|
||||
os.Stdout = stdoutW
|
||||
os.Stderr = stderrW
|
||||
t.Cleanup(func() {
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
_ = stdoutR.Close()
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrR.Close()
|
||||
_ = stderrW.Close()
|
||||
})
|
||||
|
||||
if err := Init(InitOptions{
|
||||
Level: "debug",
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: OutputOptions{
|
||||
ToStdout: true,
|
||||
ToFile: false,
|
||||
},
|
||||
Sampling: SamplingOptions{Enabled: false},
|
||||
}); err != nil {
|
||||
t.Fatalf("Init() error: %v", err)
|
||||
}
|
||||
|
||||
LegacyPrintf("service.test", "request started")
|
||||
LegacyPrintf("service.test", "Warning: queue full")
|
||||
LegacyPrintf("service.test", "forward failed: timeout")
|
||||
Sync()
|
||||
|
||||
_ = stdoutW.Close()
|
||||
_ = stderrW.Close()
|
||||
stdoutBytes, _ := io.ReadAll(stdoutR)
|
||||
stderrBytes, _ := io.ReadAll(stderrR)
|
||||
stdoutText := string(stdoutBytes)
|
||||
stderrText := string(stderrBytes)
|
||||
|
||||
if !strings.Contains(stdoutText, "request started") {
|
||||
t.Fatalf("stdout missing info log: %s", stdoutText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "Warning: queue full") {
|
||||
t.Fatalf("stderr missing warn log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "forward failed: timeout") {
|
||||
t.Fatalf("stderr missing error log: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"legacy_printf\":true") {
|
||||
t.Fatalf("stderr missing legacy_printf marker: %s", stderrText)
|
||||
}
|
||||
if !strings.Contains(stderrText, "\"component\":\"service.test\"") {
|
||||
t.Fatalf("stderr missing component field: %s", stderrText)
|
||||
}
|
||||
}
|
||||
@@ -50,6 +50,7 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
@@ -65,7 +66,9 @@ func NewSessionStore() *SessionStore {
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
// Set stores a session
|
||||
|
||||
43
backend/internal/pkg/oauth/oauth_test.go
Normal file
43
backend/internal/pkg/oauth/oauth_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
@@ -15,8 +15,8 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
const (
|
||||
// OAuth Client ID for OpenAI (Codex CLI official)
|
||||
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
// OAuth Client ID for Sora mobile flow (aligned with sora2api)
|
||||
SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK"
|
||||
|
||||
// OAuth endpoints
|
||||
AuthorizeURL = "https://auth.openai.com/oauth/authorize"
|
||||
@@ -47,6 +49,7 @@ type OAuthSession struct {
|
||||
type SessionStore struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*OAuthSession
|
||||
stopOnce sync.Once
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
@@ -92,7 +95,9 @@ func (s *SessionStore) Delete(sessionID string) {
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (s *SessionStore) Stop() {
|
||||
close(s.stopCh)
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup removes expired sessions periodically
|
||||
|
||||
43
backend/internal/pkg/openai/oauth_test.go
Normal file
43
backend/internal/pkg/openai/oauth_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSessionStore_Stop_Idempotent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
store.Stop()
|
||||
store.Stop()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 50 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
store.Stop()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-store.stopCh:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package openai
|
||||
|
||||
import "strings"
|
||||
|
||||
// CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns
|
||||
// Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2"
|
||||
var CodexCLIUserAgentPrefixes = []string{
|
||||
@@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
|
||||
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
|
||||
var CodexOfficialClientUserAgentPrefixes = []string{
|
||||
"codex_cli_rs/",
|
||||
"codex_vscode/",
|
||||
"codex_app/",
|
||||
"codex_chatgpt_desktop/",
|
||||
"codex_atlas/",
|
||||
"codex_exec/",
|
||||
"codex_sdk_ts/",
|
||||
"codex ",
|
||||
}
|
||||
|
||||
// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。
|
||||
// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。
|
||||
// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。
|
||||
var CodexOfficialClientOriginatorPrefixes = []string{
|
||||
"codex_",
|
||||
"codex ",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
for _, prefix := range CodexCLIUserAgentPrefixes {
|
||||
if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix {
|
||||
ua := normalizeCodexClientHeader(userAgent)
|
||||
if ua == "" {
|
||||
return false
|
||||
}
|
||||
return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
|
||||
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
|
||||
func IsCodexOfficialClientRequest(userAgent string) bool {
|
||||
ua := normalizeCodexClientHeader(userAgent)
|
||||
if ua == "" {
|
||||
return false
|
||||
}
|
||||
return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes)
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。
|
||||
func IsCodexOfficialClientOriginator(originator string) bool {
|
||||
v := normalizeCodexClientHeader(originator)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes)
|
||||
}
|
||||
|
||||
func normalizeCodexClientHeader(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool {
|
||||
for _, prefix := range prefixes {
|
||||
normalizedPrefix := normalizeCodexClientHeader(prefix)
|
||||
if normalizedPrefix == "" {
|
||||
continue
|
||||
}
|
||||
// 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。
|
||||
if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
87
backend/internal/pkg/openai/request_test.go
Normal file
87
backend/internal/pkg/openai/request_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsCodexCLIRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true},
|
||||
{name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true},
|
||||
{name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true},
|
||||
{name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true},
|
||||
{name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true},
|
||||
{name: "非 codex", ua: "curl/8.0.1", want: false},
|
||||
{name: "空字符串", ua: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexCLIRequest(tt.ua)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true},
|
||||
{name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true},
|
||||
{name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true},
|
||||
{name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", want: true},
|
||||
{name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true},
|
||||
{name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true},
|
||||
{name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true},
|
||||
{name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true},
|
||||
{name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true},
|
||||
{name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true},
|
||||
{name: "非 codex", ua: "curl/8.0.1", want: false},
|
||||
{name: "空字符串", ua: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientRequest(tt.ua)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientOriginator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
originator string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs", originator: "codex_cli_rs", want: true},
|
||||
{name: "codex_vscode", originator: "codex_vscode", want: true},
|
||||
{name: "codex_app", originator: "codex_app", want: true},
|
||||
{name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true},
|
||||
{name: "codex_atlas", originator: "codex_atlas", want: true},
|
||||
{name: "codex_exec", originator: "codex_exec", want: true},
|
||||
{name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true},
|
||||
{name: "Codex 前缀", originator: "Codex Desktop", want: true},
|
||||
{name: "空白包裹", originator: " codex_vscode ", want: true},
|
||||
{name: "非 codex", originator: "my_client", want: false},
|
||||
{name: "空字符串", originator: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientOriginator(tt.originator)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool {
|
||||
|
||||
// Log internal errors with full details for debugging
|
||||
if statusCode >= 500 && c.Request != nil {
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error())
|
||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
|
||||
}
|
||||
|
||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
||||
|
||||
@@ -14,6 +14,44 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------- 辅助函数 ----------
|
||||
|
||||
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
|
||||
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
|
||||
t.Helper()
|
||||
var got Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
return got
|
||||
}
|
||||
|
||||
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
|
||||
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
var pd PaginatedData
|
||||
require.NoError(t, json.Unmarshal(raw.Data, &pd))
|
||||
|
||||
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
|
||||
}
|
||||
|
||||
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
|
||||
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
|
||||
return w, c
|
||||
}
|
||||
|
||||
// ---------- 现有测试 ----------
|
||||
|
||||
func TestErrorWithDetails(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 新增测试 ----------
|
||||
|
||||
func TestSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
wantBody Response
|
||||
}{
|
||||
{
|
||||
name: "返回字符串数据",
|
||||
data: "hello",
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
|
||||
},
|
||||
{
|
||||
name: "返回nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
{
|
||||
name: "返回map数据",
|
||||
data: map[string]string{"key": "value"},
|
||||
wantCode: http.StatusOK,
|
||||
wantBody: Response{Code: 0, Message: "success"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Success(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
|
||||
if tt.data == nil {
|
||||
require.Nil(t, got.Data)
|
||||
} else {
|
||||
require.NotNil(t, got.Data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data any
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
name: "创建成功_返回数据",
|
||||
data: map[string]int{"id": 42},
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "创建成功_nil数据",
|
||||
data: nil,
|
||||
wantCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Created(c, tt.data)
|
||||
|
||||
require.Equal(t, tt.wantCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, 0, got.Code)
|
||||
require.Equal(t, "success", got.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "400错误",
|
||||
statusCode: http.StatusBadRequest,
|
||||
message: "bad request",
|
||||
},
|
||||
{
|
||||
name: "500错误",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
message: "internal error",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码",
|
||||
statusCode: 418,
|
||||
message: "I'm a teapot",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Error(c, tt.statusCode, tt.message)
|
||||
|
||||
require.Equal(t, tt.statusCode, w.Code)
|
||||
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, tt.statusCode, got.Code)
|
||||
require.Equal(t, tt.message, got.Message)
|
||||
require.Empty(t, got.Reason)
|
||||
require.Nil(t, got.Metadata)
|
||||
require.Nil(t, got.Data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
BadRequest(c, "参数无效")
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusBadRequest, got.Code)
|
||||
require.Equal(t, "参数无效", got.Message)
|
||||
}
|
||||
|
||||
func TestUnauthorized(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Unauthorized(c, "未登录")
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusUnauthorized, got.Code)
|
||||
require.Equal(t, "未登录", got.Message)
|
||||
}
|
||||
|
||||
func TestForbidden(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Forbidden(c, "无权限")
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusForbidden, got.Code)
|
||||
require.Equal(t, "无权限", got.Message)
|
||||
}
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
NotFound(c, "资源不存在")
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusNotFound, got.Code)
|
||||
require.Equal(t, "资源不存在", got.Message)
|
||||
}
|
||||
|
||||
func TestInternalError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
InternalError(c, "服务器内部错误")
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
got := parseResponseBody(t, w)
|
||||
require.Equal(t, http.StatusInternalServerError, got.Code)
|
||||
require.Equal(t, "服务器内部错误", got.Message)
|
||||
}
|
||||
|
||||
func TestPaginated(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
total int64
|
||||
page int
|
||||
pageSize int
|
||||
wantPages int
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "标准分页_多页",
|
||||
items: []string{"a", "b"},
|
||||
total: 25,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 3,
|
||||
wantTotal: 25,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数刚好整除",
|
||||
items: []string{"a"},
|
||||
total: 20,
|
||||
page: 2,
|
||||
pageSize: 10,
|
||||
wantPages: 2,
|
||||
wantTotal: 20,
|
||||
wantPage: 2,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "总数为0_pages至少为1",
|
||||
items: []string{},
|
||||
total: 0,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
{
|
||||
name: "单页数据",
|
||||
items: []int{1, 2, 3},
|
||||
total: 3,
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPages: 1,
|
||||
wantTotal: 3,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "总数为1",
|
||||
items: []string{"only"},
|
||||
total: 1,
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPages: 1,
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginatedWithResult(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
items any
|
||||
pagination *PaginationResult
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
wantPages int
|
||||
}{
|
||||
{
|
||||
name: "正常分页结果",
|
||||
items: []string{"a", "b"},
|
||||
pagination: &PaginationResult{
|
||||
Total: 50,
|
||||
Page: 3,
|
||||
PageSize: 10,
|
||||
Pages: 5,
|
||||
},
|
||||
wantTotal: 50,
|
||||
wantPage: 3,
|
||||
wantPageSize: 10,
|
||||
wantPages: 5,
|
||||
},
|
||||
{
|
||||
name: "pagination为nil_使用默认值",
|
||||
items: []string{},
|
||||
pagination: nil,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
{
|
||||
name: "单页结果",
|
||||
items: []int{1},
|
||||
pagination: &PaginationResult{
|
||||
Total: 1,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
},
|
||||
wantTotal: 1,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
PaginatedWithResult(c, tt.items, tt.pagination)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
resp, pd := parsePaginatedBody(t, w)
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Equal(t, "success", resp.Message)
|
||||
require.Equal(t, tt.wantTotal, pd.Total)
|
||||
require.Equal(t, tt.wantPage, pd.Page)
|
||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
||||
require.Equal(t, tt.wantPages, pd.Pages)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePagination(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "无参数_使用默认值",
|
||||
query: "",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page",
|
||||
query: "page=3",
|
||||
wantPage: 3,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "仅指定page_size",
|
||||
query: "page_size=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 50,
|
||||
},
|
||||
{
|
||||
name: "同时指定page和page_size",
|
||||
query: "page=2&page_size=30",
|
||||
wantPage: 2,
|
||||
wantPageSize: 30,
|
||||
},
|
||||
{
|
||||
name: "使用limit代替page_size",
|
||||
query: "limit=15",
|
||||
wantPage: 1,
|
||||
wantPageSize: 15,
|
||||
},
|
||||
{
|
||||
name: "page_size优先于limit",
|
||||
query: "page_size=25&limit=50",
|
||||
wantPage: 1,
|
||||
wantPageSize: 25,
|
||||
},
|
||||
{
|
||||
name: "page为0_使用默认值",
|
||||
query: "page=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size超过1000_使用默认值",
|
||||
query: "page_size=1001",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size恰好1000_有效",
|
||||
query: "page_size=1000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1000,
|
||||
},
|
||||
{
|
||||
name: "page为非数字_使用默认值",
|
||||
query: "page=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为非数字_使用默认值",
|
||||
query: "page_size=xyz",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为非数字_使用默认值",
|
||||
query: "limit=abc",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page_size为0_使用默认值",
|
||||
query: "page_size=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit为0_使用默认值",
|
||||
query: "limit=0",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "大页码",
|
||||
query: "page=999&page_size=100",
|
||||
wantPage: 999,
|
||||
wantPageSize: 100,
|
||||
},
|
||||
{
|
||||
name: "page_size为1_最小有效值",
|
||||
query: "page_size=1",
|
||||
wantPage: 1,
|
||||
wantPageSize: 1,
|
||||
},
|
||||
{
|
||||
name: "混合数字和字母的page",
|
||||
query: "page=12a",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "limit超过1000_使用默认值",
|
||||
query: "limit=2000",
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, c := newContextWithQuery(tt.query)
|
||||
|
||||
page, pageSize := ParsePagination(c)
|
||||
|
||||
require.Equal(t, tt.wantPage, page, "page 不符合预期")
|
||||
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantVal int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常数字",
|
||||
input: "123",
|
||||
wantVal: 123,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "零",
|
||||
input: "0",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "单个数字",
|
||||
input: "5",
|
||||
wantVal: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "大数字",
|
||||
input: "99999",
|
||||
wantVal: 99999,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含字母_返回0",
|
||||
input: "abc",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "数字开头接字母_返回0",
|
||||
input: "12a",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含负号_返回0",
|
||||
input: "-1",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含小数点_返回0",
|
||||
input: "1.5",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "包含空格_返回0",
|
||||
input: "1 2",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
wantVal: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := parseInt(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, tt.wantVal, val)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build unit
|
||||
|
||||
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
|
||||
//
|
||||
// Unit tests for TLS fingerprint dialer.
|
||||
@@ -9,26 +11,161 @@
|
||||
package tlsfingerprint
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
// TestDialerBasicConnection tests that the dialer can establish TLS connections.
|
||||
func TestDialerBasicConnection(t *testing.T) {
|
||||
skipNetworkTest(t)
|
||||
|
||||
// Create a dialer with default profile
|
||||
profile := &Profile{
|
||||
Name: "Test Profile",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
// Create HTTP client with custom TLS dialer
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Make a request to a known HTTPS endpoint
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
|
||||
// This test uses tls.peet.ws to verify the fingerprint.
|
||||
// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
|
||||
func TestJA3Fingerprint(t *testing.T) {
|
||||
skipNetworkTest(t)
|
||||
|
||||
profile := &Profile{
|
||||
Name: "Claude CLI Test",
|
||||
EnableGREASE: false,
|
||||
}
|
||||
dialer := NewDialer(profile, nil)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// Use tls.peet.ws fingerprint detection API
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
}
|
||||
|
||||
// Log all fingerprint information
|
||||
t.Logf("JA3: %s", fpResp.TLS.JA3)
|
||||
t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
|
||||
t.Logf("JA4: %s", fpResp.TLS.JA4)
|
||||
t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
|
||||
t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
|
||||
|
||||
// Verify JA3 hash matches expected value
|
||||
expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
|
||||
if fpResp.TLS.JA3Hash == expectedJA3Hash {
|
||||
t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
|
||||
} else {
|
||||
t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
|
||||
}
|
||||
|
||||
// Verify JA4 fingerprint
|
||||
// JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
|
||||
// Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
|
||||
// The suffix _a33745022dd6_1f22a2ca17c4 should match
|
||||
expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
|
||||
if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
|
||||
t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
|
||||
}
|
||||
|
||||
// Verify JA4 prefix (t13d5911h1 or t13i5911h1)
|
||||
// d = domain (SNI present), i = IP (no SNI)
|
||||
// Since we connect to tls.peet.ws (domain), we expect 'd'
|
||||
expectedJA4Prefix := "t13d5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
|
||||
t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
|
||||
} else {
|
||||
// Also accept 'i' variant for IP connections
|
||||
altPrefix := "t13i5911h1"
|
||||
if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
|
||||
t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
|
||||
} else {
|
||||
t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
|
||||
if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
|
||||
t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
|
||||
} else {
|
||||
t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
|
||||
}
|
||||
|
||||
// Verify extension list (should be 11 extensions including SNI)
|
||||
// Expected: 0-11-10-35-16-22-23-13-43-45-51
|
||||
expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
|
||||
if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
|
||||
t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
|
||||
} else {
|
||||
t.Logf("Warning: JA3 extension list may differ")
|
||||
}
|
||||
}
|
||||
|
||||
func skipNetworkTest(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("跳过网络测试(short 模式)")
|
||||
}
|
||||
if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" {
|
||||
t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerWithProfile tests that different profiles produce different fingerprints.
|
||||
@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL {
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// TestProfileExpectation defines expected fingerprint values for a profile.
|
||||
type TestProfileExpectation struct {
|
||||
Profile *Profile
|
||||
ExpectedJA3 string // Expected JA3 hash (empty = don't check)
|
||||
ExpectedJA4 string // Expected full JA4 (empty = don't check)
|
||||
JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
|
||||
}
|
||||
|
||||
// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
|
||||
// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
|
||||
func TestAllProfiles(t *testing.T) {
|
||||
skipNetworkTest(t)
|
||||
|
||||
// Define all profiles to test with their expected fingerprints
|
||||
// These profiles are from config.yaml gateway.tls_fingerprint.profiles
|
||||
profiles := []TestProfileExpectation{
|
||||
{
|
||||
// Linux x64 Node.js v22.17.1
|
||||
// Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
|
||||
// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
|
||||
Profile: &Profile{
|
||||
Name: "linux_x64_node_v22171",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part
|
||||
},
|
||||
{
|
||||
// MacOS arm64 Node.js v22.18.0
|
||||
// Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
|
||||
// Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
|
||||
Profile: &Profile{
|
||||
Name: "macos_arm64_node_v22180",
|
||||
EnableGREASE: false,
|
||||
CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
|
||||
Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
|
||||
PointFormats: []uint8{0, 1, 2},
|
||||
},
|
||||
JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range profiles {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Profile.Name, func(t *testing.T) {
|
||||
fp := fetchFingerprint(t, tc.Profile)
|
||||
if fp == nil {
|
||||
return // fetchFingerprint already called t.Fatal
|
||||
}
|
||||
|
||||
t.Logf("Profile: %s", tc.Profile.Name)
|
||||
t.Logf(" JA3: %s", fp.JA3)
|
||||
t.Logf(" JA3 Hash: %s", fp.JA3Hash)
|
||||
t.Logf(" JA4: %s", fp.JA4)
|
||||
t.Logf(" PeetPrint: %s", fp.PeetPrint)
|
||||
t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
|
||||
|
||||
// Verify expectations
|
||||
if tc.ExpectedJA3 != "" {
|
||||
if fp.JA3Hash == tc.ExpectedJA3 {
|
||||
t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.ExpectedJA4 != "" {
|
||||
if fp.JA4 == tc.ExpectedJA4 {
|
||||
t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
|
||||
}
|
||||
}
|
||||
|
||||
// Check JA4 cipher hash (stable middle part)
|
||||
// JA4 format: prefix_cipherHash_extHash
|
||||
if tc.JA4CipherHash != "" {
|
||||
if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
|
||||
t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
|
||||
} else {
|
||||
t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
|
||||
func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
|
||||
t.Helper()
|
||||
|
||||
dialer := NewDialer(profile, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialTLSContext: dialer.DialTLSContext,
|
||||
},
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get fingerprint: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var fpResp FingerprintResponse
|
||||
if err := json.Unmarshal(body, &fpResp); err != nil {
|
||||
t.Logf("Response body: %s", string(body))
|
||||
t.Fatalf("failed to parse fingerprint response: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &fpResp.TLS
|
||||
}
|
||||
|
||||
20
backend/internal/pkg/tlsfingerprint/test_types_test.go
Normal file
20
backend/internal/pkg/tlsfingerprint/test_types_test.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package tlsfingerprint
|
||||
|
||||
// FingerprintResponse represents the response from tls.peet.ws/api/all.
|
||||
// 共享测试类型,供 unit 和 integration 测试文件使用。
|
||||
type FingerprintResponse struct {
|
||||
IP string `json:"ip"`
|
||||
TLS TLSInfo `json:"tls"`
|
||||
HTTP2 any `json:"http2"`
|
||||
}
|
||||
|
||||
// TLSInfo contains TLS fingerprint details.
|
||||
type TLSInfo struct {
|
||||
JA3 string `json:"ja3"`
|
||||
JA3Hash string `json:"ja3_hash"`
|
||||
JA4 string `json:"ja4"`
|
||||
PeetPrint string `json:"peetprint"`
|
||||
PeetPrintHash string `json:"peetprint_hash"`
|
||||
ClientRandom string `json:"client_random"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
Reference in New Issue
Block a user