From 21a04332ec8e229fa66bcb70e7e6faa3df02ea0d Mon Sep 17 00:00:00 2001 From: song Date: Mon, 29 Dec 2025 17:46:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20golangci-lint=20?= =?UTF-8?q?=E6=A3=80=E6=9F=A5=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SA1029: 创建 ctxkey 包定义类型安全的 context key - ST1005: 错误字符串首字母改小写 - errcheck: 显式忽略 bytes.Buffer.Write 返回值 - 修复单元测试中 GatewayService 缺少 cfg 字段的问题 --- backend/internal/pkg/antigravity/client.go | 12 +-- .../internal/pkg/antigravity/gemini_types.go | 34 ++++----- .../pkg/antigravity/stream_transformer.go | 74 +++++++++---------- backend/internal/pkg/ctxkey/ctxkey.go | 10 +++ .../internal/server/middleware/middleware.go | 9 +-- .../service/antigravity_oauth_service.go | 4 +- .../service/gateway_multiplatform_test.go | 22 ++++++ backend/internal/service/gateway_service.go | 6 +- .../service/gemini_messages_compat_service.go | 3 +- 9 files changed, 101 insertions(+), 73 deletions(-) create mode 100644 backend/internal/pkg/ctxkey/ctxkey.go diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 4f14b0e6..e5d5b905 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -114,7 +114,7 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("Token 交换请求失败: %w", err) + return nil, fmt.Errorf("token 交换请求失败: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -124,12 +124,12 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes)) } var tokenResp TokenResponse if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil { - return nil, fmt.Errorf("Token 解析失败: %w", err) + return nil, fmt.Errorf("token 解析失败: %w", err) } return &tokenResp, nil @@ -151,7 +151,7 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("Token 刷新请求失败: %w", err) + return nil, fmt.Errorf("token 刷新请求失败: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -161,12 +161,12 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes)) + return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes)) } var tokenResp TokenResponse if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil { - return nil, fmt.Errorf("Token 解析失败: %w", err) + return nil, fmt.Errorf("token 解析失败: %w", err) } return &tokenResp, nil diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 95b9faec..2800e0ee 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -14,13 +14,13 @@ type V1InternalRequest struct { // GeminiRequest Gemini 请求内容 type GeminiRequest struct { - Contents []GeminiContent `json:"contents"` - SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"` - GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"` - Tools []GeminiToolDeclaration `json:"tools,omitempty"` - ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"` - SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"` - SessionID string `json:"sessionId,omitempty"` + Contents []GeminiContent `json:"contents"` + SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"` + GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"` + Tools []GeminiToolDeclaration `json:"tools,omitempty"` + ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"` + SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"` + SessionID string `json:"sessionId,omitempty"` } // GeminiContent Gemini 内容 @@ -31,10 +31,10 @@ type GeminiContent struct { // GeminiPart Gemini 内容部分 type GeminiPart struct { - Text string `json:"text,omitempty"` - Thought bool `json:"thought,omitempty"` - ThoughtSignature string `json:"thoughtSignature,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` + Thought bool `json:"thought,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"` FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` } @@ -61,12 +61,12 @@ type GeminiFunctionResponse struct { // GeminiGenerationConfig Gemini 生成配置 type GeminiGenerationConfig struct { - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"topP,omitempty"` - TopK *int `json:"topK,omitempty"` - ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"topP,omitempty"` + TopK *int `json:"topK,omitempty"` + ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` } // GeminiThinkingConfig Gemini thinking 配置 diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index a0611e9a..20c8444a 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -72,7 +72,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { // 发送 message_start if !p.messageStartSent { - result.Write(p.emitMessageStart(&v1Resp)) + _, _ = result.Write(p.emitMessageStart(&v1Resp)) } // 更新 usage @@ -84,7 +84,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { // 处理 parts if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil { for _, part := range geminiResp.Candidates[0].Content.Parts { - result.Write(p.processPart(&part)) + _, _ = result.Write(p.processPart(&part)) } } @@ -92,7 +92,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { if len(geminiResp.Candidates) > 0 { finishReason := geminiResp.Candidates[0].FinishReason if finishReason != "" { - result.Write(p.emitFinish(finishReason)) + _, _ = result.Write(p.emitFinish(finishReason)) } } @@ -104,7 +104,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) { var result bytes.Buffer if !p.messageStopSent { - result.Write(p.emitFinish("")) + _, _ = result.Write(p.emitFinish("")) } usage := &ClaudeUsage{ @@ -164,21 +164,21 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte { if part.FunctionCall != nil { // 先处理 trailingSignature if p.trailingSignature != "" { - result.Write(p.endBlock()) - result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) + _, _ = result.Write(p.endBlock()) + _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) p.trailingSignature = "" } - result.Write(p.processFunctionCall(part.FunctionCall, signature)) + _, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature)) return result.Bytes() } // 2. Text 处理 if part.Text != "" || part.Thought { if part.Thought { - result.Write(p.processThinking(part.Text, signature)) + _, _ = result.Write(p.processThinking(part.Text, signature)) } else { - result.Write(p.processText(part.Text, signature)) + _, _ = result.Write(p.processText(part.Text, signature)) } } @@ -186,7 +186,7 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte { if part.InlineData != nil && part.InlineData.Data != "" { markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data) - result.Write(p.processText(markdownImg, "")) + _, _ = result.Write(p.processText(markdownImg, "")) } return result.Bytes() @@ -198,21 +198,21 @@ func (p *StreamingProcessor) processThinking(text, signature string) []byte { // 处理之前的 trailingSignature if p.trailingSignature != "" { - result.Write(p.endBlock()) - result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) + _, _ = result.Write(p.endBlock()) + _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) p.trailingSignature = "" } // 开始或继续 thinking 块 if p.blockType != BlockTypeThinking { - result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{ + _, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{ "type": "thinking", "thinking": "", })) } if text != "" { - result.Write(p.emitDelta("thinking_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("thinking_delta", map[string]interface{}{ "thinking": text, })) } @@ -239,34 +239,34 @@ func (p *StreamingProcessor) processText(text, signature string) []byte { // 处理之前的 trailingSignature if p.trailingSignature != "" { - result.Write(p.endBlock()) - result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) + _, _ = result.Write(p.endBlock()) + _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) p.trailingSignature = "" } // 非空 text 带签名 - 特殊处理 if signature != "" { - result.Write(p.startBlock(BlockTypeText, map[string]interface{}{ + _, _ = result.Write(p.startBlock(BlockTypeText, map[string]interface{}{ "type": "text", "text": "", })) - result.Write(p.emitDelta("text_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("text_delta", map[string]interface{}{ "text": text, })) - result.Write(p.endBlock()) - result.Write(p.emitEmptyThinkingWithSignature(signature)) + _, _ = result.Write(p.endBlock()) + _, _ = result.Write(p.emitEmptyThinkingWithSignature(signature)) return result.Bytes() } // 普通 text (无签名) if p.blockType != BlockTypeText { - result.Write(p.startBlock(BlockTypeText, map[string]interface{}{ + _, _ = result.Write(p.startBlock(BlockTypeText, map[string]interface{}{ "type": "text", "text": "", })) } - result.Write(p.emitDelta("text_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("text_delta", map[string]interface{}{ "text": text, })) @@ -295,17 +295,17 @@ func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signatu toolUse["signature"] = signature } - result.Write(p.startBlock(BlockTypeFunction, toolUse)) + _, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse)) // 发送 input_json_delta if fc.Args != nil { argsJSON, _ := json.Marshal(fc.Args) - result.Write(p.emitDelta("input_json_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("input_json_delta", map[string]interface{}{ "partial_json": string(argsJSON), })) } - result.Write(p.endBlock()) + _, _ = result.Write(p.endBlock()) return result.Bytes() } @@ -315,7 +315,7 @@ func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[st var result bytes.Buffer if p.blockType != BlockTypeNone { - result.Write(p.endBlock()) + _, _ = result.Write(p.endBlock()) } event := map[string]interface{}{ @@ -324,7 +324,7 @@ func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[st "content_block": contentBlock, } - result.Write(p.formatSSE("content_block_start", event)) + _, _ = result.Write(p.formatSSE("content_block_start", event)) p.blockType = blockType return result.Bytes() @@ -340,7 +340,7 @@ func (p *StreamingProcessor) endBlock() []byte { // Thinking 块结束时发送暂存的签名 if p.blockType == BlockTypeThinking && p.pendingSignature != "" { - result.Write(p.emitDelta("signature_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("signature_delta", map[string]interface{}{ "signature": p.pendingSignature, })) p.pendingSignature = "" @@ -351,7 +351,7 @@ func (p *StreamingProcessor) endBlock() []byte { "index": p.blockIndex, } - result.Write(p.formatSSE("content_block_stop", event)) + _, _ = result.Write(p.formatSSE("content_block_stop", event)) p.blockIndex++ p.blockType = BlockTypeNone @@ -381,17 +381,17 @@ func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte { var result bytes.Buffer - result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{ + _, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{ "type": "thinking", "thinking": "", })) - result.Write(p.emitDelta("thinking_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("thinking_delta", map[string]interface{}{ "thinking": "", })) - result.Write(p.emitDelta("signature_delta", map[string]interface{}{ + _, _ = result.Write(p.emitDelta("signature_delta", map[string]interface{}{ "signature": signature, })) - result.Write(p.endBlock()) + _, _ = result.Write(p.endBlock()) return result.Bytes() } @@ -401,11 +401,11 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { var result bytes.Buffer // 关闭最后一个块 - result.Write(p.endBlock()) + _, _ = result.Write(p.endBlock()) // 处理 trailingSignature if p.trailingSignature != "" { - result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) + _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) p.trailingSignature = "" } @@ -431,13 +431,13 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { "usage": usage, } - result.Write(p.formatSSE("message_delta", deltaEvent)) + _, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) if !p.messageStopSent { stopEvent := map[string]interface{}{ "type": "message_stop", } - result.Write(p.formatSSE("message_stop", stopEvent)) + _, _ = result.Write(p.formatSSE("message_stop", stopEvent)) p.messageStopSent = true } diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go new file mode 100644 index 00000000..8920ea69 --- /dev/null +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -0,0 +1,10 @@ +// Package ctxkey 定义用于 context.Value 的类型安全 key +package ctxkey + +// Key 定义 context key 的类型,避免使用内置 string 类型(staticcheck SA1029) +type Key string + +const ( + // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 + ForcePlatform Key = "ctx_force_platform" +) diff --git a/backend/internal/server/middleware/middleware.go b/backend/internal/server/middleware/middleware.go index 45643164..75b9f68e 100644 --- a/backend/internal/server/middleware/middleware.go +++ b/backend/internal/server/middleware/middleware.go @@ -3,6 +3,7 @@ package middleware import ( "context" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/gin-gonic/gin" ) @@ -22,16 +23,12 @@ const ( ContextKeyForcePlatform ContextKey = "force_platform" ) -// ctxKeyForcePlatformStr 用于 request.Context 的字符串 key(供 Service 读取) -// 注意:service 包中也需要使用相同的字符串 "ctx_force_platform" -const ctxKeyForcePlatformStr = "ctx_force_platform" - // ForcePlatform 返回设置强制平台的中间件 // 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查) func ForcePlatform(platform string) gin.HandlerFunc { return func(c *gin.Context) { - // 设置到 request.Context,使用字符串 key 供 Service 层读取 - ctx := context.WithValue(c.Request.Context(), ctxKeyForcePlatformStr, platform) + // 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取 + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform) c.Request = c.Request.WithContext(ctx) // 同时设置到 gin.Context,供 Handler 快速检查 c.Set(string(ContextKeyForcePlatform), platform) diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index 57565631..fc6cc74d 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -116,7 +116,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) if err != nil { - return nil, fmt.Errorf("Token 交换失败: %w", err) + return nil, fmt.Errorf("token 交换失败: %w", err) } // 删除 session @@ -184,7 +184,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken lastErr = err } - return nil, fmt.Errorf("Token 刷新失败 (重试后): %w", lastErr) + return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr) } func isNonRetryableAntigravityOAuthError(err error) bool { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 4dfef5f4..b54d7b4a 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -8,10 +8,16 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" ) +// testConfig 返回一个用于测试的默认配置 +func testConfig() *config.Config { + return &config.Config{RunMode: config.RunModeStandard} +} + // mockAccountRepoForPlatform 单平台测试用的 mock type mockAccountRepoForPlatform struct { accounts []Account @@ -177,6 +183,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Anthropic(t *testing.T svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -206,6 +213,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity) @@ -236,6 +244,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -258,6 +267,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -286,6 +296,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded(t *testing svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } excludedIDs := map[int64]struct{}{1: {}, 2: {}} @@ -361,6 +372,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Schedulability(t *test svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -394,6 +406,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -421,6 +434,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } // 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户 @@ -450,6 +464,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } excludedIDs := map[int64]struct{}{1: {}} @@ -478,6 +493,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -569,6 +585,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -594,6 +611,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -622,6 +640,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -649,6 +668,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -673,6 +693,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) @@ -698,6 +719,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) { svc := &GatewayService{ accountRepo: repo, cache: cache, + cfg: testConfig(), } acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 08e3c1d1..e88e757a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -30,9 +31,6 @@ const ( stickySessionTTL = time.Hour // 粘性会话TTL ) -// ctxKeyForcePlatform 用于从 context 读取强制平台(由 middleware.ForcePlatform 设置) -// 必须与 middleware.ctxKeyForcePlatformStr 使用相同的字符串值 -const ctxKeyForcePlatform = "ctx_force_platform" // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). @@ -300,7 +298,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 优先检查 context 中的强制平台(/antigravity 路由) var platform string - forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string) + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { platform = forcePlatform } else if groupID != nil { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 025ca888..c7374ad6 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -18,6 +18,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" @@ -74,7 +75,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 优先检查 context 中的强制平台(/antigravity 路由) var platform string - forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string) + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { platform = forcePlatform } else if groupID != nil {