fix: 修复 golangci-lint 检查错误

- SA1029: 创建 ctxkey 包定义类型安全的 context key
- ST1005: 错误字符串首字母改小写
- errcheck: 显式忽略 bytes.Buffer.Write 返回值
- 修复单元测试中 GatewayService 缺少 cfg 字段的问题
This commit is contained in:
song
2025-12-29 17:46:52 +08:00
parent eec8b4c91e
commit 21a04332ec
9 changed files with 101 additions and 73 deletions

View File

@@ -114,7 +114,7 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("Token 交换请求失败: %w", err) return nil, fmt.Errorf("token 交换请求失败: %w", err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
@@ -124,12 +124,12 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes)) return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
} }
var tokenResp TokenResponse var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil { if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("Token 解析失败: %w", err) return nil, fmt.Errorf("token 解析失败: %w", err)
} }
return &tokenResp, nil return &tokenResp, nil
@@ -151,7 +151,7 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("Token 刷新请求失败: %w", err) return nil, fmt.Errorf("token 刷新请求失败: %w", err)
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
@@ -161,12 +161,12 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes)) return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
} }
var tokenResp TokenResponse var tokenResp TokenResponse
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil { if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
return nil, fmt.Errorf("Token 解析失败: %w", err) return nil, fmt.Errorf("token 解析失败: %w", err)
} }
return &tokenResp, nil return &tokenResp, nil

View File

@@ -14,13 +14,13 @@ type V1InternalRequest struct {
// GeminiRequest Gemini 请求内容 // GeminiRequest Gemini 请求内容
type GeminiRequest struct { type GeminiRequest struct {
Contents []GeminiContent `json:"contents"` Contents []GeminiContent `json:"contents"`
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"` SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"` GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiToolDeclaration `json:"tools,omitempty"` Tools []GeminiToolDeclaration `json:"tools,omitempty"`
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"` ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"` SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
SessionID string `json:"sessionId,omitempty"` SessionID string `json:"sessionId,omitempty"`
} }
// GeminiContent Gemini 内容 // GeminiContent Gemini 内容
@@ -31,10 +31,10 @@ type GeminiContent struct {
// GeminiPart Gemini 内容部分 // GeminiPart Gemini 内容部分
type GeminiPart struct { type GeminiPart struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"` Thought bool `json:"thought,omitempty"`
ThoughtSignature string `json:"thoughtSignature,omitempty"` ThoughtSignature string `json:"thoughtSignature,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"` InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"` FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
} }
@@ -61,12 +61,12 @@ type GeminiFunctionResponse struct {
// GeminiGenerationConfig Gemini 生成配置 // GeminiGenerationConfig Gemini 生成配置
type GeminiGenerationConfig struct { type GeminiGenerationConfig struct {
MaxOutputTokens int `json:"maxOutputTokens,omitempty"` MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"topP,omitempty"` TopP *float64 `json:"topP,omitempty"`
TopK *int `json:"topK,omitempty"` TopK *int `json:"topK,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"` StopSequences []string `json:"stopSequences,omitempty"`
} }
// GeminiThinkingConfig Gemini thinking 配置 // GeminiThinkingConfig Gemini thinking 配置

View File

@@ -72,7 +72,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 发送 message_start // 发送 message_start
if !p.messageStartSent { if !p.messageStartSent {
result.Write(p.emitMessageStart(&v1Resp)) _, _ = result.Write(p.emitMessageStart(&v1Resp))
} }
// 更新 usage // 更新 usage
@@ -84,7 +84,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 处理 parts // 处理 parts
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil { if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
for _, part := range geminiResp.Candidates[0].Content.Parts { for _, part := range geminiResp.Candidates[0].Content.Parts {
result.Write(p.processPart(&part)) _, _ = result.Write(p.processPart(&part))
} }
} }
@@ -92,7 +92,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
if len(geminiResp.Candidates) > 0 { if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason finishReason := geminiResp.Candidates[0].FinishReason
if finishReason != "" { if finishReason != "" {
result.Write(p.emitFinish(finishReason)) _, _ = result.Write(p.emitFinish(finishReason))
} }
} }
@@ -104,7 +104,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
var result bytes.Buffer var result bytes.Buffer
if !p.messageStopSent { if !p.messageStopSent {
result.Write(p.emitFinish("")) _, _ = result.Write(p.emitFinish(""))
} }
usage := &ClaudeUsage{ usage := &ClaudeUsage{
@@ -164,21 +164,21 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
if part.FunctionCall != nil { if part.FunctionCall != nil {
// 先处理 trailingSignature // 先处理 trailingSignature
if p.trailingSignature != "" { if p.trailingSignature != "" {
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = "" p.trailingSignature = ""
} }
result.Write(p.processFunctionCall(part.FunctionCall, signature)) _, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
return result.Bytes() return result.Bytes()
} }
// 2. Text 处理 // 2. Text 处理
if part.Text != "" || part.Thought { if part.Text != "" || part.Thought {
if part.Thought { if part.Thought {
result.Write(p.processThinking(part.Text, signature)) _, _ = result.Write(p.processThinking(part.Text, signature))
} else { } else {
result.Write(p.processText(part.Text, signature)) _, _ = result.Write(p.processText(part.Text, signature))
} }
} }
@@ -186,7 +186,7 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
if part.InlineData != nil && part.InlineData.Data != "" { if part.InlineData != nil && part.InlineData.Data != "" {
markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)", markdownImg := fmt.Sprintf("![image](data:%s;base64,%s)",
part.InlineData.MimeType, part.InlineData.Data) part.InlineData.MimeType, part.InlineData.Data)
result.Write(p.processText(markdownImg, "")) _, _ = result.Write(p.processText(markdownImg, ""))
} }
return result.Bytes() return result.Bytes()
@@ -198,21 +198,21 @@ func (p *StreamingProcessor) processThinking(text, signature string) []byte {
// 处理之前的 trailingSignature // 处理之前的 trailingSignature
if p.trailingSignature != "" { if p.trailingSignature != "" {
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = "" p.trailingSignature = ""
} }
// 开始或继续 thinking 块 // 开始或继续 thinking 块
if p.blockType != BlockTypeThinking { if p.blockType != BlockTypeThinking {
result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{ _, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{
"type": "thinking", "type": "thinking",
"thinking": "", "thinking": "",
})) }))
} }
if text != "" { if text != "" {
result.Write(p.emitDelta("thinking_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("thinking_delta", map[string]interface{}{
"thinking": text, "thinking": text,
})) }))
} }
@@ -239,34 +239,34 @@ func (p *StreamingProcessor) processText(text, signature string) []byte {
// 处理之前的 trailingSignature // 处理之前的 trailingSignature
if p.trailingSignature != "" { if p.trailingSignature != "" {
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = "" p.trailingSignature = ""
} }
// 非空 text 带签名 - 特殊处理 // 非空 text 带签名 - 特殊处理
if signature != "" { if signature != "" {
result.Write(p.startBlock(BlockTypeText, map[string]interface{}{ _, _ = result.Write(p.startBlock(BlockTypeText, map[string]interface{}{
"type": "text", "type": "text",
"text": "", "text": "",
})) }))
result.Write(p.emitDelta("text_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("text_delta", map[string]interface{}{
"text": text, "text": text,
})) }))
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
result.Write(p.emitEmptyThinkingWithSignature(signature)) _, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
return result.Bytes() return result.Bytes()
} }
// 普通 text (无签名) // 普通 text (无签名)
if p.blockType != BlockTypeText { if p.blockType != BlockTypeText {
result.Write(p.startBlock(BlockTypeText, map[string]interface{}{ _, _ = result.Write(p.startBlock(BlockTypeText, map[string]interface{}{
"type": "text", "type": "text",
"text": "", "text": "",
})) }))
} }
result.Write(p.emitDelta("text_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("text_delta", map[string]interface{}{
"text": text, "text": text,
})) }))
@@ -295,17 +295,17 @@ func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signatu
toolUse["signature"] = signature toolUse["signature"] = signature
} }
result.Write(p.startBlock(BlockTypeFunction, toolUse)) _, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
// 发送 input_json_delta // 发送 input_json_delta
if fc.Args != nil { if fc.Args != nil {
argsJSON, _ := json.Marshal(fc.Args) argsJSON, _ := json.Marshal(fc.Args)
result.Write(p.emitDelta("input_json_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("input_json_delta", map[string]interface{}{
"partial_json": string(argsJSON), "partial_json": string(argsJSON),
})) }))
} }
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
return result.Bytes() return result.Bytes()
} }
@@ -315,7 +315,7 @@ func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[st
var result bytes.Buffer var result bytes.Buffer
if p.blockType != BlockTypeNone { if p.blockType != BlockTypeNone {
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
} }
event := map[string]interface{}{ event := map[string]interface{}{
@@ -324,7 +324,7 @@ func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[st
"content_block": contentBlock, "content_block": contentBlock,
} }
result.Write(p.formatSSE("content_block_start", event)) _, _ = result.Write(p.formatSSE("content_block_start", event))
p.blockType = blockType p.blockType = blockType
return result.Bytes() return result.Bytes()
@@ -340,7 +340,7 @@ func (p *StreamingProcessor) endBlock() []byte {
// Thinking 块结束时发送暂存的签名 // Thinking 块结束时发送暂存的签名
if p.blockType == BlockTypeThinking && p.pendingSignature != "" { if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
result.Write(p.emitDelta("signature_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("signature_delta", map[string]interface{}{
"signature": p.pendingSignature, "signature": p.pendingSignature,
})) }))
p.pendingSignature = "" p.pendingSignature = ""
@@ -351,7 +351,7 @@ func (p *StreamingProcessor) endBlock() []byte {
"index": p.blockIndex, "index": p.blockIndex,
} }
result.Write(p.formatSSE("content_block_stop", event)) _, _ = result.Write(p.formatSSE("content_block_stop", event))
p.blockIndex++ p.blockIndex++
p.blockType = BlockTypeNone p.blockType = BlockTypeNone
@@ -381,17 +381,17 @@ func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte { func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
var result bytes.Buffer var result bytes.Buffer
result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{ _, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{
"type": "thinking", "type": "thinking",
"thinking": "", "thinking": "",
})) }))
result.Write(p.emitDelta("thinking_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("thinking_delta", map[string]interface{}{
"thinking": "", "thinking": "",
})) }))
result.Write(p.emitDelta("signature_delta", map[string]interface{}{ _, _ = result.Write(p.emitDelta("signature_delta", map[string]interface{}{
"signature": signature, "signature": signature,
})) }))
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
return result.Bytes() return result.Bytes()
} }
@@ -401,11 +401,11 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
var result bytes.Buffer var result bytes.Buffer
// 关闭最后一个块 // 关闭最后一个块
result.Write(p.endBlock()) _, _ = result.Write(p.endBlock())
// 处理 trailingSignature // 处理 trailingSignature
if p.trailingSignature != "" { if p.trailingSignature != "" {
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature)) _, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
p.trailingSignature = "" p.trailingSignature = ""
} }
@@ -431,13 +431,13 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
"usage": usage, "usage": usage,
} }
result.Write(p.formatSSE("message_delta", deltaEvent)) _, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
if !p.messageStopSent { if !p.messageStopSent {
stopEvent := map[string]interface{}{ stopEvent := map[string]interface{}{
"type": "message_stop", "type": "message_stop",
} }
result.Write(p.formatSSE("message_stop", stopEvent)) _, _ = result.Write(p.formatSSE("message_stop", stopEvent))
p.messageStopSent = true p.messageStopSent = true
} }

View File

@@ -0,0 +1,10 @@
// Package ctxkey 定义用于 context.Value 的类型安全 key
package ctxkey
// Key 定义 context key 的类型,避免使用内置 string 类型staticcheck SA1029
type Key string
const (
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
ForcePlatform Key = "ctx_force_platform"
)

View File

@@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -22,16 +23,12 @@ const (
ContextKeyForcePlatform ContextKey = "force_platform" ContextKeyForcePlatform ContextKey = "force_platform"
) )
// ctxKeyForcePlatformStr 用于 request.Context 的字符串 key供 Service 读取)
// 注意service 包中也需要使用相同的字符串 "ctx_force_platform"
const ctxKeyForcePlatformStr = "ctx_force_platform"
// ForcePlatform 返回设置强制平台的中间件 // ForcePlatform 返回设置强制平台的中间件
// 同时设置 request.Context供 Service 使用)和 gin.Context供 Handler 快速检查) // 同时设置 request.Context供 Service 使用)和 gin.Context供 Handler 快速检查)
func ForcePlatform(platform string) gin.HandlerFunc { func ForcePlatform(platform string) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 设置到 request.Context使用字符串 key 供 Service 层读取 // 设置到 request.Context使用 ctxkey.ForcePlatform 供 Service 层读取
ctx := context.WithValue(c.Request.Context(), ctxKeyForcePlatformStr, platform) ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
// 同时设置到 gin.Context供 Handler 快速检查 // 同时设置到 gin.Context供 Handler 快速检查
c.Set(string(ContextKeyForcePlatform), platform) c.Set(string(ContextKeyForcePlatform), platform)

View File

@@ -116,7 +116,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
// 交换 token // 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
if err != nil { if err != nil {
return nil, fmt.Errorf("Token 交换失败: %w", err) return nil, fmt.Errorf("token 交换失败: %w", err)
} }
// 删除 session // 删除 session
@@ -184,7 +184,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
lastErr = err lastErr = err
} }
return nil, fmt.Errorf("Token 刷新失败 (重试后): %w", lastErr) return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
} }
func isNonRetryableAntigravityOAuthError(err error) bool { func isNonRetryableAntigravityOAuthError(err error) bool {

View File

@@ -8,10 +8,16 @@ import (
"testing" "testing"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testConfig 返回一个用于测试的默认配置
func testConfig() *config.Config {
return &config.Config{RunMode: config.RunModeStandard}
}
// mockAccountRepoForPlatform 单平台测试用的 mock // mockAccountRepoForPlatform 单平台测试用的 mock
type mockAccountRepoForPlatform struct { type mockAccountRepoForPlatform struct {
accounts []Account accounts []Account
@@ -177,6 +183,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Anthropic(t *testing.T
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -206,6 +213,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity)
@@ -236,6 +244,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -258,6 +267,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -286,6 +296,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded(t *testing
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
excludedIDs := map[int64]struct{}{1: {}, 2: {}} excludedIDs := map[int64]struct{}{1: {}, 2: {}}
@@ -361,6 +372,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Schedulability(t *test
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -394,6 +406,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -421,6 +434,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
// 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户 // 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户
@@ -450,6 +464,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
excludedIDs := map[int64]struct{}{1: {}} excludedIDs := map[int64]struct{}{1: {}}
@@ -478,6 +493,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -569,6 +585,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -594,6 +611,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -622,6 +640,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -649,6 +668,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -673,6 +693,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
@@ -698,6 +719,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
svc := &GatewayService{ svc := &GatewayService{
accountRepo: repo, accountRepo: repo,
cache: cache, cache: cache,
cfg: testConfig(),
} }
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic) acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)

View File

@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -30,9 +31,6 @@ const (
stickySessionTTL = time.Hour // 粘性会话TTL stickySessionTTL = time.Hour // 粘性会话TTL
) )
// ctxKeyForcePlatform 用于从 context 读取强制平台(由 middleware.ForcePlatform 设置)
// 必须与 middleware.ctxKeyForcePlatformStr 使用相同的字符串值
const ctxKeyForcePlatform = "ctx_force_platform"
// sseDataRe matches SSE data lines with optional whitespace after colon. // sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: "). // Some upstream APIs return non-standard "data:" without space (should be "data: ").
@@ -300,7 +298,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由) // 优先检查 context 中的强制平台(/antigravity 路由)
var platform string var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" { if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform platform = forcePlatform
} else if groupID != nil { } else if groupID != nil {

View File

@@ -18,6 +18,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
@@ -74,7 +75,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由) // 优先检查 context 中的强制平台(/antigravity 路由)
var platform string var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string) forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" { if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform platform = forcePlatform
} else if groupID != nil { } else if groupID != nil {