fix: 修复 golangci-lint 检查错误
- SA1029: 创建 ctxkey 包定义类型安全的 context key - ST1005: 错误字符串首字母改小写 - errcheck: 显式忽略 bytes.Buffer.Write 返回值 - 修复单元测试中 GatewayService 缺少 cfg 字段的问题
This commit is contained in:
@@ -114,7 +114,7 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Token 交换请求失败: %w", err)
|
||||
return nil, fmt.Errorf("token 交换请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@@ -124,12 +124,12 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
return nil, fmt.Errorf("token 交换失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("Token 解析失败: %w", err)
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
@@ -151,7 +151,7 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Token 刷新请求失败: %w", err)
|
||||
return nil, fmt.Errorf("token 刷新请求失败: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
@@ -161,12 +161,12 @@ func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenR
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("Token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
return nil, fmt.Errorf("token 刷新失败 (HTTP %d): %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("Token 解析失败: %w", err)
|
||||
return nil, fmt.Errorf("token 解析失败: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
|
||||
@@ -14,13 +14,13 @@ type V1InternalRequest struct {
|
||||
|
||||
// GeminiRequest Gemini 请求内容
|
||||
type GeminiRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
|
||||
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
|
||||
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
SystemInstruction *GeminiContent `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []GeminiToolDeclaration `json:"tools,omitempty"`
|
||||
ToolConfig *GeminiToolConfig `json:"toolConfig,omitempty"`
|
||||
SafetySettings []GeminiSafetySetting `json:"safetySettings,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiContent Gemini 内容
|
||||
@@ -31,10 +31,10 @@ type GeminiContent struct {
|
||||
|
||||
// GeminiPart Gemini 内容部分
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
ThoughtSignature string `json:"thoughtSignature,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
@@ -61,12 +61,12 @@ type GeminiFunctionResponse struct {
|
||||
|
||||
// GeminiGenerationConfig Gemini 生成配置
|
||||
type GeminiGenerationConfig struct {
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiThinkingConfig Gemini thinking 配置
|
||||
|
||||
@@ -72,7 +72,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
|
||||
// 发送 message_start
|
||||
if !p.messageStartSent {
|
||||
result.Write(p.emitMessageStart(&v1Resp))
|
||||
_, _ = result.Write(p.emitMessageStart(&v1Resp))
|
||||
}
|
||||
|
||||
// 更新 usage
|
||||
@@ -84,7 +84,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
// 处理 parts
|
||||
if len(geminiResp.Candidates) > 0 && geminiResp.Candidates[0].Content != nil {
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
result.Write(p.processPart(&part))
|
||||
_, _ = result.Write(p.processPart(&part))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||
if len(geminiResp.Candidates) > 0 {
|
||||
finishReason := geminiResp.Candidates[0].FinishReason
|
||||
if finishReason != "" {
|
||||
result.Write(p.emitFinish(finishReason))
|
||||
_, _ = result.Write(p.emitFinish(finishReason))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func (p *StreamingProcessor) Finish() ([]byte, *ClaudeUsage) {
|
||||
var result bytes.Buffer
|
||||
|
||||
if !p.messageStopSent {
|
||||
result.Write(p.emitFinish(""))
|
||||
_, _ = result.Write(p.emitFinish(""))
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{
|
||||
@@ -164,21 +164,21 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
if part.FunctionCall != nil {
|
||||
// 先处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
result.Write(p.endBlock())
|
||||
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
result.Write(p.processFunctionCall(part.FunctionCall, signature))
|
||||
_, _ = result.Write(p.processFunctionCall(part.FunctionCall, signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 2. Text 处理
|
||||
if part.Text != "" || part.Thought {
|
||||
if part.Thought {
|
||||
result.Write(p.processThinking(part.Text, signature))
|
||||
_, _ = result.Write(p.processThinking(part.Text, signature))
|
||||
} else {
|
||||
result.Write(p.processText(part.Text, signature))
|
||||
_, _ = result.Write(p.processText(part.Text, signature))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte {
|
||||
if part.InlineData != nil && part.InlineData.Data != "" {
|
||||
markdownImg := fmt.Sprintf("",
|
||||
part.InlineData.MimeType, part.InlineData.Data)
|
||||
result.Write(p.processText(markdownImg, ""))
|
||||
_, _ = result.Write(p.processText(markdownImg, ""))
|
||||
}
|
||||
|
||||
return result.Bytes()
|
||||
@@ -198,21 +198,21 @@ func (p *StreamingProcessor) processThinking(text, signature string) []byte {
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
result.Write(p.endBlock())
|
||||
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 开始或继续 thinking 块
|
||||
if p.blockType != BlockTypeThinking {
|
||||
result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
}
|
||||
|
||||
if text != "" {
|
||||
result.Write(p.emitDelta("thinking_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]interface{}{
|
||||
"thinking": text,
|
||||
}))
|
||||
}
|
||||
@@ -239,34 +239,34 @@ func (p *StreamingProcessor) processText(text, signature string) []byte {
|
||||
|
||||
// 处理之前的 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
result.Write(p.endBlock())
|
||||
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
// 非空 text 带签名 - 特殊处理
|
||||
if signature != "" {
|
||||
result.Write(p.startBlock(BlockTypeText, map[string]interface{}{
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
result.Write(p.emitDelta("text_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]interface{}{
|
||||
"text": text,
|
||||
}))
|
||||
result.Write(p.endBlock())
|
||||
result.Write(p.emitEmptyThinkingWithSignature(signature))
|
||||
_, _ = result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(signature))
|
||||
return result.Bytes()
|
||||
}
|
||||
|
||||
// 普通 text (无签名)
|
||||
if p.blockType != BlockTypeText {
|
||||
result.Write(p.startBlock(BlockTypeText, map[string]interface{}{
|
||||
_, _ = result.Write(p.startBlock(BlockTypeText, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
}))
|
||||
}
|
||||
|
||||
result.Write(p.emitDelta("text_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("text_delta", map[string]interface{}{
|
||||
"text": text,
|
||||
}))
|
||||
|
||||
@@ -295,17 +295,17 @@ func (p *StreamingProcessor) processFunctionCall(fc *GeminiFunctionCall, signatu
|
||||
toolUse["signature"] = signature
|
||||
}
|
||||
|
||||
result.Write(p.startBlock(BlockTypeFunction, toolUse))
|
||||
_, _ = result.Write(p.startBlock(BlockTypeFunction, toolUse))
|
||||
|
||||
// 发送 input_json_delta
|
||||
if fc.Args != nil {
|
||||
argsJSON, _ := json.Marshal(fc.Args)
|
||||
result.Write(p.emitDelta("input_json_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("input_json_delta", map[string]interface{}{
|
||||
"partial_json": string(argsJSON),
|
||||
}))
|
||||
}
|
||||
|
||||
result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
@@ -315,7 +315,7 @@ func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[st
|
||||
var result bytes.Buffer
|
||||
|
||||
if p.blockType != BlockTypeNone {
|
||||
result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.endBlock())
|
||||
}
|
||||
|
||||
event := map[string]interface{}{
|
||||
@@ -324,7 +324,7 @@ func (p *StreamingProcessor) startBlock(blockType BlockType, contentBlock map[st
|
||||
"content_block": contentBlock,
|
||||
}
|
||||
|
||||
result.Write(p.formatSSE("content_block_start", event))
|
||||
_, _ = result.Write(p.formatSSE("content_block_start", event))
|
||||
p.blockType = blockType
|
||||
|
||||
return result.Bytes()
|
||||
@@ -340,7 +340,7 @@ func (p *StreamingProcessor) endBlock() []byte {
|
||||
|
||||
// Thinking 块结束时发送暂存的签名
|
||||
if p.blockType == BlockTypeThinking && p.pendingSignature != "" {
|
||||
result.Write(p.emitDelta("signature_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]interface{}{
|
||||
"signature": p.pendingSignature,
|
||||
}))
|
||||
p.pendingSignature = ""
|
||||
@@ -351,7 +351,7 @@ func (p *StreamingProcessor) endBlock() []byte {
|
||||
"index": p.blockIndex,
|
||||
}
|
||||
|
||||
result.Write(p.formatSSE("content_block_stop", event))
|
||||
_, _ = result.Write(p.formatSSE("content_block_stop", event))
|
||||
|
||||
p.blockIndex++
|
||||
p.blockType = BlockTypeNone
|
||||
@@ -381,17 +381,17 @@ func (p *StreamingProcessor) emitDelta(deltaType string, deltaContent map[string
|
||||
func (p *StreamingProcessor) emitEmptyThinkingWithSignature(signature string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{
|
||||
_, _ = result.Write(p.startBlock(BlockTypeThinking, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": "",
|
||||
}))
|
||||
result.Write(p.emitDelta("thinking_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("thinking_delta", map[string]interface{}{
|
||||
"thinking": "",
|
||||
}))
|
||||
result.Write(p.emitDelta("signature_delta", map[string]interface{}{
|
||||
_, _ = result.Write(p.emitDelta("signature_delta", map[string]interface{}{
|
||||
"signature": signature,
|
||||
}))
|
||||
result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
return result.Bytes()
|
||||
}
|
||||
@@ -401,11 +401,11 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
var result bytes.Buffer
|
||||
|
||||
// 关闭最后一个块
|
||||
result.Write(p.endBlock())
|
||||
_, _ = result.Write(p.endBlock())
|
||||
|
||||
// 处理 trailingSignature
|
||||
if p.trailingSignature != "" {
|
||||
result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
_, _ = result.Write(p.emitEmptyThinkingWithSignature(p.trailingSignature))
|
||||
p.trailingSignature = ""
|
||||
}
|
||||
|
||||
@@ -431,13 +431,13 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||
|
||||
if !p.messageStopSent {
|
||||
stopEvent := map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
}
|
||||
result.Write(p.formatSSE("message_stop", stopEvent))
|
||||
_, _ = result.Write(p.formatSSE("message_stop", stopEvent))
|
||||
p.messageStopSent = true
|
||||
}
|
||||
|
||||
|
||||
10
backend/internal/pkg/ctxkey/ctxkey.go
Normal file
10
backend/internal/pkg/ctxkey/ctxkey.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Package ctxkey 定义用于 context.Value 的类型安全 key
|
||||
package ctxkey
|
||||
|
||||
// Key 定义 context key 的类型,避免使用内置 string 类型(staticcheck SA1029)
|
||||
type Key string
|
||||
|
||||
const (
|
||||
// ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置
|
||||
ForcePlatform Key = "ctx_force_platform"
|
||||
)
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -22,16 +23,12 @@ const (
|
||||
ContextKeyForcePlatform ContextKey = "force_platform"
|
||||
)
|
||||
|
||||
// ctxKeyForcePlatformStr 用于 request.Context 的字符串 key(供 Service 读取)
|
||||
// 注意:service 包中也需要使用相同的字符串 "ctx_force_platform"
|
||||
const ctxKeyForcePlatformStr = "ctx_force_platform"
|
||||
|
||||
// ForcePlatform 返回设置强制平台的中间件
|
||||
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
|
||||
func ForcePlatform(platform string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置到 request.Context,使用字符串 key 供 Service 层读取
|
||||
ctx := context.WithValue(c.Request.Context(), ctxKeyForcePlatformStr, platform)
|
||||
// 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
// 同时设置到 gin.Context,供 Handler 快速检查
|
||||
c.Set(string(ContextKeyForcePlatform), platform)
|
||||
|
||||
@@ -116,7 +116,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
// 交换 token
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Token 交换失败: %w", err)
|
||||
return nil, fmt.Errorf("token 交换失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除 session
|
||||
@@ -184,7 +184,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Token 刷新失败 (重试后): %w", lastErr)
|
||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||
|
||||
@@ -8,10 +8,16 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testConfig 返回一个用于测试的默认配置
|
||||
func testConfig() *config.Config {
|
||||
return &config.Config{RunMode: config.RunModeStandard}
|
||||
}
|
||||
|
||||
// mockAccountRepoForPlatform 单平台测试用的 mock
|
||||
type mockAccountRepoForPlatform struct {
|
||||
accounts []Account
|
||||
@@ -177,6 +183,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Anthropic(t *testing.T
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -206,6 +213,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Antigravity(t *testing
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAntigravity)
|
||||
@@ -236,6 +244,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -258,6 +267,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -286,6 +296,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_AllExcluded(t *testing
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}, 2: {}}
|
||||
@@ -361,6 +372,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_Schedulability(t *test
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -394,6 +406,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -421,6 +434,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
// 请求 anthropic 平台,但粘性会话绑定的是 antigravity 账户
|
||||
@@ -450,6 +464,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
@@ -478,6 +493,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_StickySession(t *testi
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -569,6 +585,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -594,6 +611,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -622,6 +640,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -649,6 +668,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "session-123", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -673,6 +693,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
@@ -698,6 +719,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, PlatformAnthropic)
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
@@ -30,9 +31,6 @@ const (
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// ctxKeyForcePlatform 用于从 context 读取强制平台(由 middleware.ForcePlatform 设置)
|
||||
// 必须与 middleware.ctxKeyForcePlatformStr 使用相同的字符串值
|
||||
const ctxKeyForcePlatform = "ctx_force_platform"
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
@@ -300,7 +298,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
|
||||
@@ -74,7 +75,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context,
|
||||
func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
|
||||
// 优先检查 context 中的强制平台(/antigravity 路由)
|
||||
var platform string
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxKeyForcePlatform).(string)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
|
||||
Reference in New Issue
Block a user