diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index fffcd5f9..8596b8ba 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -12,7 +12,6 @@ import ( "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { wire.Build( // Infrastructure layer ProviderSets config.ProviderSet, - infrastructure.ProviderSet, // Business layer ProviderSets repository.ProviderSet, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index e3498680..1adabefe 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -12,7 +12,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler/admin" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { if err != nil { return nil, err } - client, err := infrastructure.ProvideEnt(configConfig) + client, err := repository.ProvideEnt(configConfig) if err != nil { return nil, err } - db, err := infrastructure.ProvideSQLDB(client) + db, err := repository.ProvideSQLDB(client) if err != nil { return nil, err } userRepository := repository.NewUserRepository(client, db) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) - redisClient := infrastructure.ProvideRedis(configConfig) + redisClient := repository.ProvideRedis(configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 8c154a9d..7927fec5 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -121,6 +121,17 @@ type GatewayConfig struct { // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + // 是否记录上游错误响应体摘要(避免输出请求内容) + LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` + // 上游错误响应体记录最大字节数(超过会截断) + LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` + + // API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容) + InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"` + + // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) + FailoverOn400 bool `mapstructure:"failover_on_400"` + // Scheduling: 账号调度相关配置 Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` } @@ -334,6 +345,10 @@ func setDefaults() { // Gateway viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头,LLM高负载时可能排队较久 + viper.SetDefault("gateway.log_upstream_error_body", false) + viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) + viper.SetDefault("gateway.inject_beta_for_apikey", false) + viper.SetDefault("gateway.failover_on_400", false) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) diff --git a/backend/internal/infrastructure/wire.go b/backend/internal/infrastructure/wire.go deleted file mode 100644 index 1e64640c..00000000 --- a/backend/internal/infrastructure/wire.go +++ /dev/null @@ -1,79 +0,0 @@ -package infrastructure - -import ( - "database/sql" - "errors" - - "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/internal/config" - - "github.com/google/wire" - "github.com/redis/go-redis/v9" - - entsql "entgo.io/ent/dialect/sql" -) - -// ProviderSet 是基础设施层的 Wire 依赖提供者集合。 -// -// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数 -// 组织在一起,便于在应用程序启动时自动组装依赖关系。 -// -// 包含的提供者: -// - ProvideEnt: 提供 Ent ORM 客户端 -// - ProvideSQLDB: 提供底层 SQL 数据库连接 -// - ProvideRedis: 提供 Redis 客户端 -var ProviderSet = wire.NewSet( - ProvideEnt, - ProvideSQLDB, - ProvideRedis, -) - -// ProvideEnt 为依赖注入提供 Ent 客户端。 -// -// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。 -// Wire 会在编译时分析依赖关系,自动生成初始化代码。 -// -// 依赖:config.Config -// 提供:*ent.Client -func ProvideEnt(cfg *config.Config) (*ent.Client, error) { - client, _, err := InitEnt(cfg) - return client, err -} - -// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。 -// -// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询), -// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。 -// -// 设计说明: -// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问 -// - 这种设计允许在同一事务中混用 Ent 和原生 SQL -// -// 依赖:*ent.Client -// 提供:*sql.DB -func ProvideSQLDB(client *ent.Client) (*sql.DB, error) { - if client == nil { - return nil, errors.New("nil ent client") - } - // 从 Ent 客户端获取底层驱动 - drv, ok := client.Driver().(*entsql.Driver) - if !ok { - return nil, errors.New("ent driver does not expose *sql.DB") - } - // 返回驱动持有的 sql.DB 实例 - return drv.DB(), nil -} - -// ProvideRedis 为依赖注入提供 Redis 客户端。 -// -// Redis 用于: -// - 分布式锁(如并发控制) -// - 缓存(如用户会话、API 响应缓存) -// - 速率限制 -// - 实时统计数据 -// -// 依赖:config.Config -// 提供:*redis.Client -func ProvideRedis(cfg *config.Config) *redis.Client { - return InitRedis(cfg) -} diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 01b805cd..34e6b1f4 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -54,6 +54,9 @@ type CustomToolSpec struct { InputSchema map[string]any `json:"input_schema"` } +// ClaudeCustomToolSpec 兼容旧命名(MCP custom 工具规格) +type ClaudeCustomToolSpec = CustomToolSpec + // SystemBlock system prompt 数组形式的元素 type SystemBlock struct { Type string `json:"type"` diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index e0b5b886..83b87a32 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) - // 检测是否启用 thinking - isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" - // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") + // 检测是否启用 thinking + requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" + // 为避免 Claude 模型的 thought signature/消息块约束导致 400(上游要求 thinking 块开头等), + // 非 Gemini 模型默认不启用 thinking(除非未来支持完整签名链路)。 + isThinkingEnabled := requestedThinkingEnabled && allowDummyThought + // 1. 构建 contents contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) if err != nil { @@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) // 3. 构建 generationConfig - generationConfig := buildGenerationConfig(claudeReq) + reqForGen := claudeReq + if requestedThinkingEnabled && !allowDummyThought { + log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel) + // shallow copy to avoid mutating caller's request + clone := *claudeReq + clone.Thinking = nil + reqForGen = &clone + } + generationConfig := buildGenerationConfig(reqForGen) // 4. 构建 tools tools := buildTools(claudeReq.Tools) @@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT if !hasThoughtPart && len(parts) > 0 { // 在开头添加 dummy thinking block parts = append([]GeminiPart{{ - Text: "Thinking...", - Thought: true, + Text: "Thinking...", + Thought: true, + ThoughtSignature: dummyThoughtSignature, }}, parts...) } } @@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures const dummyThoughtSignature = "skip_thought_signature_validator" +// isValidThoughtSignature 验证 thought signature 是否有效 +// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节 +func isValidThoughtSignature(signature string) bool { + // 空字符串无效 + if signature == "" { + return false + } + + // signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节) + // 参考 Claude API 文档和实际观察到的有效 signature + if len(signature) < 40 { + log.Printf("[Debug] Signature too short: len=%d", len(signature)) + return false + } + + // 检查是否是有效的 base64 字符 + // base64 字符集: A-Z, a-z, 0-9, +, /, = + for i, c := range signature { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && + (c < '0' || c > '9') && c != '+' && c != '/' && c != '=' { + log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c) + return false + } + } + + return true +} + // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { @@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu } case "thinking": - part := GeminiPart{ - Text: block.Thinking, - Thought: true, - } - // 保留原有 signature(Claude 模型需要有效的 signature) - if block.Signature != "" { - part.ThoughtSignature = block.Signature - } else if !allowDummyThought { - // Claude 模型需要有效 signature,跳过无 signature 的 thinking block - log.Printf("Warning: skipping thinking block without signature for Claude model") + if allowDummyThought { + // Gemini 模型可以使用 dummy signature + parts = append(parts, GeminiPart{ + Text: block.Thinking, + Thought: true, + ThoughtSignature: dummyThoughtSignature, + }) continue - } else { - // Gemini 模型使用 dummy signature - part.ThoughtSignature = dummyThoughtSignature } - parts = append(parts, part) + + // Claude 模型:仅在提供有效 signature 时保留 thinking block;否则跳过以避免上游校验失败。 + signature := strings.TrimSpace(block.Signature) + if signature == "" || signature == dummyThoughtSignature { + log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)") + continue + } + if !isValidThoughtSignature(signature) { + log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature)) + } + parts = append(parts, GeminiPart{ + Text: block.Thinking, + Thought: true, + ThoughtSignature: signature, + }) case "image": if block.Source != nil && block.Source.Type == "base64" { @@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu ID: block.ID, }, } - // 保留原有 signature,或对 Gemini 模型使用 dummy signature - if block.Signature != "" { - part.ThoughtSignature = block.Signature - } else if allowDummyThought { + // 只有 Gemini 模型使用 dummy signature + // Claude 模型不设置 signature(避免验证问题) + if allowDummyThought { part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) @@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 普通工具 var funcDecls []GeminiFunctionDecl - for _, tool := range tools { + for i, tool := range tools { // 跳过无效工具名称 - if tool.Name == "" { + if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") continue } @@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { var inputSchema map[string]any // 检查是否为 custom 类型工具 (MCP) - if tool.Type == "custom" && tool.Custom != nil { - // Custom 格式: 从 custom 字段获取 description 和 input_schema + if tool.Type == "custom" { + if tool.Custom == nil || tool.Custom.InputSchema == nil { + log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name) + continue + } description = tool.Custom.Description inputSchema = tool.Custom.InputSchema + + // 调试日志:记录 custom 工具的 schema + if schemaJSON, err := json.Marshal(inputSchema); err == nil { + log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON)) + } } else { // 标准格式: 从顶层字段获取 description = tool.Description @@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { // 清理 JSON Schema params := cleanJSONSchema(inputSchema) - // 为 nil schema 提供默认值 if params == nil { params = map[string]any{ @@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } } + // 调试日志:记录清理后的 schema + if paramsJSON, err := json.Marshal(params); err == nil { + log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON)) + } + funcDecls = append(funcDecls, GeminiFunctionDecl{ Name: tool.Name, Description: description, @@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any { } // excludedSchemaKeys 不支持的 schema 字段 +// 基于 Claude API (Vertex AI) 的实际支持情况 +// 支持: type, description, enum, properties, required, additionalProperties, items +// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段 var excludedSchemaKeys = map[string]bool{ - "$schema": true, - "$id": true, - "$ref": true, - "additionalProperties": true, - "minLength": true, - "maxLength": true, - "minItems": true, - "maxItems": true, - "uniqueItems": true, - "minimum": true, - "maximum": true, - "exclusiveMinimum": true, - "exclusiveMaximum": true, - "pattern": true, - "format": true, - "default": true, - "strict": true, - "const": true, - "examples": true, - "deprecated": true, - "readOnly": true, - "writeOnly": true, - "contentMediaType": true, - "contentEncoding": true, + // 元 schema 字段 + "$schema": true, + "$id": true, + "$ref": true, + + // 字符串验证(Gemini 不支持) + "minLength": true, + "maxLength": true, + "pattern": true, + + // 数字验证(Claude API 通过 Vertex AI 不支持这些字段) + "minimum": true, + "maximum": true, + "exclusiveMinimum": true, + "exclusiveMaximum": true, + "multipleOf": true, + + // 数组验证(Claude API 通过 Vertex AI 不支持这些字段) + "uniqueItems": true, + "minItems": true, + "maxItems": true, + + // 组合 schema(Gemini 不支持) + "oneOf": true, + "anyOf": true, + "allOf": true, + "not": true, + "if": true, + "then": true, + "else": true, + "$defs": true, + "definitions": true, + + // 对象验证(仅保留 properties/required/additionalProperties) + "minProperties": true, + "maxProperties": true, + "patternProperties": true, + "propertyNames": true, + "dependencies": true, + "dependentSchemas": true, + "dependentRequired": true, + + // 其他不支持的字段 + "default": true, + "const": true, + "examples": true, + "deprecated": true, + "readOnly": true, + "writeOnly": true, + "contentMediaType": true, + "contentEncoding": true, + + // Claude 特有字段 + "strict": true, } // cleanSchemaValue 递归清理 schema 值 @@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any { continue } + // 特殊处理 format 字段:只保留 Gemini 支持的 format 值 + if k == "format" { + if formatStr, ok := val.(string); ok { + // Gemini 只支持 date-time, date, time + if formatStr == "date-time" || formatStr == "date" || formatStr == "time" { + result[k] = val + } + // 其他 format 值直接跳过 + } + continue + } + + // 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象 + if k == "additionalProperties" { + if boolVal, ok := val.(bool); ok { + result[k] = boolVal + log.Printf("[Debug] additionalProperties is bool: %v", boolVal) + } else { + // 如果是 schema 对象,转换为 false(更安全的默认值) + result[k] = false + log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val) + } + continue + } + // 递归清理所有值 result[k] = cleanSchemaValue(val) } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go new file mode 100644 index 00000000..56eebad0 --- /dev/null +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -0,0 +1,179 @@ +package antigravity + +import ( + "encoding/json" + "testing" +) + +// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理 +func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { + tests := []struct { + name string + content string + allowDummyThought bool + expectedParts int + description string + }{ + { + name: "Claude model - skip thinking block without signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": ""}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: false, + expectedParts: 2, // 只有两个text block + description: "Claude模型应该跳过无signature的thinking block", + }, + { + name: "Claude model - keep thinking block with signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: false, + expectedParts: 3, // 三个block都保留 + description: "Claude模型应该保留有signature的thinking block", + }, + { + name: "Gemini model - use dummy signature", + content: `[ + {"type": "text", "text": "Hello"}, + {"type": "thinking", "thinking": "Let me think...", "signature": ""}, + {"type": "text", "text": "World"} + ]`, + allowDummyThought: true, + expectedParts: 3, // 三个block都保留,thinking使用dummy signature + description: "Gemini模型应该为无signature的thinking block使用dummy signature", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toolIDToName := make(map[string]string) + parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought) + + if err != nil { + t.Fatalf("buildParts() error = %v", err) + } + + if len(parts) != tt.expectedParts { + t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts) + } + }) + } +} + +// TestBuildTools_CustomTypeTools 测试custom类型工具转换 +func TestBuildTools_CustomTypeTools(t *testing.T) { + tests := []struct { + name string + tools []ClaudeTool + expectedLen int + description string + }{ + { + name: "Standard tool format", + tools: []ClaudeTool{ + { + Name: "get_weather", + Description: "Get weather information", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "location": map[string]any{"type": "string"}, + }, + }, + }, + }, + expectedLen: 1, + description: "标准工具格式应该正常转换", + }, + { + name: "Custom type tool (MCP format)", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "mcp_tool", + Custom: &ClaudeCustomToolSpec{ + Description: "MCP tool description", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "param": map[string]any{"type": "string"}, + }, + }, + }, + }, + }, + expectedLen: 1, + description: "Custom类型工具应该从Custom字段读取description和input_schema", + }, + { + name: "Mixed standard and custom tools", + tools: []ClaudeTool{ + { + Name: "standard_tool", + Description: "Standard tool", + InputSchema: map[string]any{"type": "object"}, + }, + { + Type: "custom", + Name: "custom_tool", + Custom: &ClaudeCustomToolSpec{ + Description: "Custom tool", + InputSchema: map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, // 返回一个GeminiToolDeclaration,包含2个function declarations + description: "混合标准和custom工具应该都能正确转换", + }, + { + name: "Invalid custom tool - nil Custom field", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "invalid_custom", + // Custom 为 nil + }, + }, + expectedLen: 0, // 应该被跳过 + description: "Custom字段为nil的custom工具应该被跳过", + }, + { + name: "Invalid custom tool - nil InputSchema", + tools: []ClaudeTool{ + { + Type: "custom", + Name: "invalid_custom", + Custom: &ClaudeCustomToolSpec{ + Description: "Invalid", + // InputSchema 为 nil + }, + }, + }, + expectedLen: 0, // 应该被跳过 + description: "InputSchema为nil的custom工具应该被跳过", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildTools(tt.tools) + + if len(result) != tt.expectedLen { + t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen) + } + + // 验证function declarations存在 + if len(result) > 0 && result[0].FunctionDeclarations != nil { + if len(result[0].FunctionDeclarations) != len(tt.tools) { + t.Errorf("%s: got %d function declarations, want %d", + tt.description, len(result[0].FunctionDeclarations), len(tt.tools)) + } + } + }) + } +} diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index 97ad6c83..0db3ed4a 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking +// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth) +const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming + +// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code) +const ApiKeyHaikuBetaHeader = BetaInterleavedThinking + // Claude Code 客户端默认请求头 var DefaultHeaders = map[string]string{ "User-Agent": "claude-cli/2.0.62 (external, cli)", diff --git a/backend/internal/infrastructure/errors/errors.go b/backend/internal/pkg/errors/errors.go similarity index 100% rename from backend/internal/infrastructure/errors/errors.go rename to backend/internal/pkg/errors/errors.go diff --git a/backend/internal/infrastructure/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go similarity index 100% rename from backend/internal/infrastructure/errors/errors_test.go rename to backend/internal/pkg/errors/errors_test.go diff --git a/backend/internal/infrastructure/errors/http.go b/backend/internal/pkg/errors/http.go similarity index 100% rename from backend/internal/infrastructure/errors/http.go rename to backend/internal/pkg/errors/http.go diff --git a/backend/internal/infrastructure/errors/types.go b/backend/internal/pkg/errors/types.go similarity index 100% rename from backend/internal/infrastructure/errors/types.go rename to backend/internal/pkg/errors/types.go diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index e26d2531..87dc4264 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -4,7 +4,7 @@ import ( "math" "net/http" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" ) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index 13b184af..ef31ca3c 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -9,7 +9,7 @@ import ( "net/http/httptest" "testing" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "application_error", - err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), + err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), wantWritten: true, wantHTTPCode: http.StatusForbidden, wantBody: Response{ @@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "bad_request_error", - err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"), + err: errors2.BadRequest("INVALID_REQUEST", "invalid request"), wantWritten: true, wantHTTPCode: http.StatusBadRequest, wantBody: Response{ @@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "unauthorized_error", - err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"), + err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"), wantWritten: true, wantHTTPCode: http.StatusUnauthorized, wantBody: Response{ @@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "not_found_error", - err: infraerrors.NotFound("NOT_FOUND", "not found"), + err: errors2.NotFound("NOT_FOUND", "not found"), wantWritten: true, wantHTTPCode: http.StatusNotFound, wantBody: Response{ @@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) { }, { name: "conflict_error", - err: infraerrors.Conflict("CONFLICT", "conflict"), + err: errors2.Conflict("CONFLICT", "conflict"), wantWritten: true, wantHTTPCode: http.StatusConflict, wantBody: Response{ @@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) { wantHTTPCode: http.StatusInternalServerError, wantBody: Response{ Code: http.StatusInternalServerError, - Message: infraerrors.UnknownMessage, + Message: errors2.UnknownMessage, }, }, } diff --git a/backend/internal/infrastructure/db_pool.go b/backend/internal/repository/db_pool.go similarity index 97% rename from backend/internal/infrastructure/db_pool.go rename to backend/internal/repository/db_pool.go index 612155bf..d7116ab1 100644 --- a/backend/internal/infrastructure/db_pool.go +++ b/backend/internal/repository/db_pool.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "database/sql" diff --git a/backend/internal/infrastructure/db_pool_test.go b/backend/internal/repository/db_pool_test.go similarity index 98% rename from backend/internal/infrastructure/db_pool_test.go rename to backend/internal/repository/db_pool_test.go index 0f0e9716..3868106a 100644 --- a/backend/internal/infrastructure/db_pool_test.go +++ b/backend/internal/repository/db_pool_test.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "database/sql" diff --git a/backend/internal/infrastructure/ent.go b/backend/internal/repository/ent.go similarity index 99% rename from backend/internal/infrastructure/ent.go rename to backend/internal/repository/ent.go index b1ab9a55..d457ba72 100644 --- a/backend/internal/infrastructure/ent.go +++ b/backend/internal/repository/ent.go @@ -1,6 +1,6 @@ // Package infrastructure 提供应用程序的基础设施层组件。 // 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。 -package infrastructure +package repository import ( "context" diff --git a/backend/internal/repository/error_translate.go b/backend/internal/repository/error_translate.go index 192f9261..b8065ffe 100644 --- a/backend/internal/repository/error_translate.go +++ b/backend/internal/repository/error_translate.go @@ -7,7 +7,7 @@ import ( "strings" dbent "github.com/Wei-Shaw/sub2api/ent" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/lib/pq" ) diff --git a/backend/internal/repository/integration_harness_test.go b/backend/internal/repository/integration_harness_test.go index 6ef447e1..fb9c26c4 100644 --- a/backend/internal/repository/integration_harness_test.go +++ b/backend/internal/repository/integration_harness_test.go @@ -17,7 +17,6 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" _ "github.com/Wei-Shaw/sub2api/ent/runtime" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -97,7 +96,7 @@ func TestMain(m *testing.M) { log.Printf("failed to open sql db: %v", err) os.Exit(1) } - if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil { + if err := ApplyMigrations(ctx, integrationDB); err != nil { log.Printf("failed to apply db migrations: %v", err) os.Exit(1) } diff --git a/backend/internal/infrastructure/migrations_runner.go b/backend/internal/repository/migrations_runner.go similarity index 99% rename from backend/internal/infrastructure/migrations_runner.go rename to backend/internal/repository/migrations_runner.go index 8477c031..e556b9ce 100644 --- a/backend/internal/infrastructure/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "context" diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index 49d96445..4c7848b2 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -7,7 +7,6 @@ import ( "database/sql" "testing" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/stretchr/testify/require" ) @@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { tx := testTx(t) // Re-apply migrations to verify idempotency (no errors, no duplicate rows). - require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB)) + require.NoError(t, ApplyMigrations(context.Background(), integrationDB)) // schema_migrations should have at least the current migration set. var applied int diff --git a/backend/internal/infrastructure/redis.go b/backend/internal/repository/redis.go similarity index 98% rename from backend/internal/infrastructure/redis.go rename to backend/internal/repository/redis.go index 9f4c8770..f3606ad9 100644 --- a/backend/internal/infrastructure/redis.go +++ b/backend/internal/repository/redis.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "time" diff --git a/backend/internal/infrastructure/redis_test.go b/backend/internal/repository/redis_test.go similarity index 97% rename from backend/internal/infrastructure/redis_test.go rename to backend/internal/repository/redis_test.go index 5e38e826..756a63dc 100644 --- a/backend/internal/infrastructure/redis_test.go +++ b/backend/internal/repository/redis_test.go @@ -1,4 +1,4 @@ -package infrastructure +package repository import ( "testing" diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 2b308674..cd3b9db6 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) } -// IncrementUsage 原子性地累加用量并校验限额。 -// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 -// 当更新失败时,会执行额外查询确定具体超出的限额类型。 +// IncrementUsage 原子性地累加订阅用量。 +// 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成, +// 此处仅负责记录实际消费,确保消费数据的完整性。 func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { - // 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 - // NULL 限额表示无限制 - const atomicUpdateSQL = ` + const updateSQL = ` UPDATE user_subscriptions us SET daily_usage_usd = us.daily_usage_usd + $1, @@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 AND us.deleted_at IS NULL AND us.group_id = g.id AND g.deleted_at IS NULL - AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd) - AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd) - AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd) ` client := clientFromContext(ctx, r.client) - result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) + result, err := client.ExecContext(ctx, updateSQL, costUSD, id) if err != nil { return err } @@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6 } if affected > 0 { - return nil // 更新成功 + return nil } - // affected == 0:可能是订阅不存在、分组已删除、或限额超出 - // 执行额外查询确定具体原因 - return r.checkIncrementFailureReason(ctx, id, costUSD) -} - -// checkIncrementFailureReason 查询更新失败的具体原因 -func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error { - const checkSQL = ` - SELECT - CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted' - WHEN g.id IS NULL THEN 'subscription_not_found' - WHEN g.deleted_at IS NOT NULL THEN 'group_deleted' - WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded' - WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded' - WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded' - ELSE 'unknown' - END AS reason - FROM user_subscriptions us - LEFT JOIN groups g ON us.group_id = g.id - WHERE us.id = $2 - ` - - client := clientFromContext(ctx, r.client) - rows, err := client.QueryContext(ctx, checkSQL, costUSD, id) - if err != nil { - return err - } - defer func() { _ = rows.Close() }() - - if !rows.Next() { - return service.ErrSubscriptionNotFound - } - - var reason string - if err := rows.Scan(&reason); err != nil { - return err - } - - if err := rows.Err(); err != nil { - return err - } - - switch reason { - case "subscription_not_found", "subscription_deleted", "group_deleted": - return service.ErrSubscriptionNotFound - case "daily_exceeded": - return service.ErrDailyLimitExceeded - case "weekly_exceeded": - return service.ErrWeeklyLimitExceeded - case "monthly_exceeded": - return service.ErrMonthlyLimitExceeded - default: - // unknown 情况理论上不应发生,但作为兜底返回 - return service.ErrSubscriptionNotFound - } + // affected == 0:订阅不存在或已删除 + return service.ErrSubscriptionNotFound } func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 3a6c6434..2099e5d8 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") } -// --- 限额检查与软删除过滤测试 --- - -func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group { - s.T().Helper() - - create := s.client.Group.Create(). - SetName(name). - SetStatus(service.StatusActive). - SetSubscriptionType(service.SubscriptionTypeSubscription) - - if daily != nil { - create.SetDailyLimitUsd(*daily) - } - if weekly != nil { - create.SetWeeklyLimitUsd(*weekly) - } - if monthly != nil { - create.SetMonthlyLimitUsd(*monthly) - } - - g, err := create.Save(s.ctx) - s.Require().NoError(err, "create group with limits") - return groupEntityToService(g) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() { - user := s.mustCreateUser("dailylimit@test.com", service.RoleUser) - dailyLimit := 10.0 - group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 先增加 9.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 2.0,会超过 10.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0) - s.Require().Error(err, "should fail when daily limit exceeded") - s.Require().ErrorIs(err, service.ErrDailyLimitExceeded) - - // 验证用量没有变化 - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment") -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() { - user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser) - weeklyLimit := 50.0 - group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 增加 45.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 10.0,会超过 50.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) - s.Require().Error(err, "should fail when weekly limit exceeded") - s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() { - user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser) - monthlyLimit := 100.0 - group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 增加 90.0,应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0) - s.Require().NoError(err, "first increment should succeed") - - // 再增加 20.0,会超过 100.0 限额,应该失败 - err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0) - s.Require().Error(err, "should fail when monthly limit exceeded") - s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() { - user := s.mustCreateUser("nolimits@test.com", service.RoleUser) - group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额 - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 应该可以增加任意金额 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0) - s.Require().NoError(err, "should succeed without limits") - - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6) -} - -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() { - user := s.mustCreateUser("exactlimit@test.com", service.RoleUser) - dailyLimit := 10.0 - group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 正好达到限额应该成功 - err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0) - s.Require().NoError(err, "should succeed at exact limit") - - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6) -} +// --- 软删除过滤测试 --- func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) @@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { user := s.mustCreateUser("concurrent@test.com", service.RoleUser) - group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 + group := s.mustCreateGroup("g-concurrent") sub := s.mustCreateSubscription(user.ID, group.ID, nil) const numGoroutines = 10 @@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") } -func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() { - user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser) - dailyLimit := 5.0 - group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil) - sub := s.mustCreateSubscription(user.ID, group.ID, nil) - - // 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑 - // 尝试增加 10 次,每次 1.0,但限额只有 5.0 - const numAttempts = 10 - const incrementPerAttempt = 1.0 - - successCount := 0 - for i := 0; i < numAttempts; i++ { - err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt) - if err == nil { - successCount++ - } - } - - // 验证:应该有 5 次成功(不超过限额),5 次失败(超出限额) - s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)") - - // 验证最终用量等于限额 - got, err := s.repo.GetByID(s.ctx, sub.ID) - s.Require().NoError(err) - s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit") -} - func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { baseClient := testEntClient(s.T()) tx, err := baseClient.Tx(context.Background()) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index f1a8d4cf..0d579b23 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -1,6 +1,11 @@ package repository import ( + "database/sql" + "errors" + + entsql "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/google/wire" @@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet( NewOpenAIOAuthClient, NewGeminiOAuthClient, NewGeminiCliCodeAssistClient, + + ProvideEnt, + ProvideSQLDB, + ProvideRedis, ) + +// ProvideEnt 为依赖注入提供 Ent 客户端。 +// +// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。 +// Wire 会在编译时分析依赖关系,自动生成初始化代码。 +// +// 依赖:config.Config +// 提供:*ent.Client +func ProvideEnt(cfg *config.Config) (*ent.Client, error) { + client, _, err := InitEnt(cfg) + return client, err +} + +// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。 +// +// 某些 Repository 需要直接执行原生 SQL(如复杂的批量更新、聚合查询), +// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。 +// +// 设计说明: +// - Ent 底层使用 sql.DB,通过 Driver 接口可以访问 +// - 这种设计允许在同一事务中混用 Ent 和原生 SQL +// +// 依赖:*ent.Client +// 提供:*sql.DB +func ProvideSQLDB(client *ent.Client) (*sql.DB, error) { + if client == nil { + return nil, errors.New("nil ent client") + } + // 从 Ent 客户端获取底层驱动 + drv, ok := client.Driver().(*entsql.Driver) + if !ok { + return nil, errors.New("ent driver does not expose *sql.DB") + } + // 返回驱动持有的 sql.DB 实例 + return drv.DB(), nil +} + +// ProvideRedis 为依赖注入提供 Redis 客户端。 +// +// Redis 用于: +// - 分布式锁(如并发控制) +// - 缓存(如用户会话、API 响应缓存) +// - 速率限制 +// - 实时统计数据 +// +// 依赖:config.Config +// 提供:*redis.Client +func ProvideRedis(cfg *config.Config) *redis.Client { + return InitRedis(cfg) +} diff --git a/backend/internal/server/middleware/recovery.go b/backend/internal/server/middleware/recovery.go index 04ea6f9d..f05154d3 100644 --- a/backend/internal/server/middleware/recovery.go +++ b/backend/internal/server/middleware/recovery.go @@ -7,7 +7,7 @@ import ( "os" "strings" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/gin-gonic/gin" ) diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 5edb6da0..439f44cb 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -8,7 +8,7 @@ import ( "net/http/httptest" "testing" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 05895c8b..3c5841bd 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 4be09810..feeb19a0 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn subscriptionType = SubscriptionTypeStandard } + // 限额字段:0 和 nil 都表示"无限制" + dailyLimit := normalizeLimit(input.DailyLimitUSD) + weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) + monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) + group := &Group{ Name: input.Name, Description: input.Description, @@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn IsExclusive: input.IsExclusive, Status: StatusActive, SubscriptionType: subscriptionType, - DailyLimitUSD: input.DailyLimitUSD, - WeeklyLimitUSD: input.WeeklyLimitUSD, - MonthlyLimitUSD: input.MonthlyLimitUSD, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return group, nil } +// normalizeLimit 将 0 或负数转换为 nil(表示无限制) +func normalizeLimit(limit *float64) *float64 { + if limit == nil || *limit <= 0 { + return nil + } + return limit +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.SubscriptionType != "" { group.SubscriptionType = input.SubscriptionType } - // 限额字段支持设置为nil(清除限额)或具体值 + // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 if input.DailyLimitUSD != nil { - group.DailyLimitUSD = input.DailyLimitUSD + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) } if input.WeeklyLimitUSD != nil { - group.WeeklyLimitUSD = input.WeeklyLimitUSD + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) } if input.MonthlyLimitUSD != nil { - group.MonthlyLimitUSD = input.MonthlyLimitUSD + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) } if err := s.groupRepo.Update(ctx, group); err != nil { diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index ae2976f8..5b3bf565 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return nil, fmt.Errorf("transform request: %w", err) } + // 调试:记录转换后的请求体(仅记录前 2000 字符) + if bodyJSON, err := json.Marshal(geminiBody); err == nil { + truncated := string(bodyJSON) + if len(truncated) > 2000 { + truncated = truncated[:2000] + "..." + } + log.Printf("[Debug] Transformed Gemini request: %s", truncated) + } + // 构建上游 action action := "generateContent" if claudeReq.Stream { diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index b4739025..9dd4463f 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -2,6 +2,7 @@ package service import ( "context" + "fmt" "time" ) @@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool { } // NeedsRefresh 检查账户是否需要刷新 -// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置 +// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置 func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool { if !r.CanRefresh(account) { return false @@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati if expiresAt == nil { return false } - return time.Until(*expiresAt) < antigravityRefreshWindow + timeUntilExpiry := time.Until(*expiresAt) + needsRefresh := timeUntilExpiry < antigravityRefreshWindow + if needsRefresh { + fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n", + account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow) + } + return needsRefresh } // Refresh 执行 token 刷新 diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index facf997e..f22c383a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -8,7 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" ) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 54bbfa5c..69765520 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -8,7 +8,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 58ed555a..9cdeed7b 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -9,7 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) // 错误定义 diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 7b4db611..6537b01e 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -10,7 +10,7 @@ import ( "strconv" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index af9342b1..cb60131b 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -1061,6 +1062,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if s.shouldFailoverOn400(respBody) { + if s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Account %d: 400 error, attempting failover: %s", + account.ID, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } else { + log.Printf("Account %d: 400 error, attempting failover", account.ID) + } + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + } return s.handleErrorResponse(ctx, resp, c, account) } @@ -1163,6 +1188,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } } return req, nil @@ -1215,6 +1247,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) return claude.DefaultBetaHeader } +func requestNeedsBetaFeatures(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { + return true + } + if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { + return true + } + return false +} + +func defaultApiKeyBetaHeader(body []byte) string { + modelID := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.ApiKeyHaikuBetaHeader + } + return claude.ApiKeyBetaHeader +} + +func truncateForLog(b []byte, maxBytes int) string { + if maxBytes <= 0 { + maxBytes = 2048 + } + if len(b) > maxBytes { + b = b[:maxBytes] + } + s := string(b) + // 保持一行,避免污染日志格式 + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\r") + return s +} + +func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { + // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 默认保守:无法识别则不切换。 + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 + // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 + if strings.Contains(msg, "anthropic-beta") || + strings.Contains(msg, "beta feature") || + strings.Contains(msg, "requires beta") { + return true + } + + // thinking/tool streaming 等兼容性约束(常见于中间转换链路) + if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { + return true + } + if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { + return true + } + + return false +} + +func extractUpstreamErrorMessage(body []byte) string { + // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} + if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { + inner := strings.TrimSpace(m) + // 有些上游会把完整 JSON 作为字符串塞进 message + if strings.HasPrefix(inner, "{") { + if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + return m + } + + // 兜底:尝试顶层 message + return gjson.GetBytes(body, "message").String() +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(resp.Body) @@ -1227,6 +1336,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res switch resp.StatusCode { case 400: + // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Upstream 400 error (account=%d platform=%s type=%s): %s", + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } c.Data(http.StatusBadRequest, "application/json", body) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) case 401: @@ -1706,6 +1825,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 标记账号状态(429/529等) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + // 记录上游错误摘要便于排障(不回显请求内容) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + // 返回简化的错误响应 errMsg := "Upstream request failed" switch resp.StatusCode { @@ -1786,6 +1917,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } } return req, nil diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index a0bf1b6a..b1877800 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any { "properties": map[string]any{}, } } + // 清理 JSON Schema + cleanedParams := cleanToolSchema(params) funcDecls = append(funcDecls, map[string]any{ "name": name, "description": desc, - "parameters": params, + "parameters": cleanedParams, }) } @@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any { } } +// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段 +func cleanToolSchema(schema any) any { + if schema == nil { + return nil + } + + switch v := schema.(type) { + case map[string]any: + cleaned := make(map[string]any) + for key, value := range v { + // 跳过不支持的字段 + if key == "$schema" || key == "$id" || key == "$ref" || + key == "additionalProperties" || key == "minLength" || + key == "maxLength" || key == "minItems" || key == "maxItems" { + continue + } + // 递归清理嵌套对象 + cleaned[key] = cleanToolSchema(value) + } + // 规范化 type 字段为大写 + if typeVal, ok := cleaned["type"].(string); ok { + cleaned["type"] = strings.ToUpper(typeVal) + } + return cleaned + case []any: + cleaned := make([]any, len(v)) + for i, item := range v { + cleaned[i] = cleanToolSchema(item) + } + return cleaned + default: + return v + } +} + func convertClaudeGenerationConfig(req map[string]any) map[string]any { out := make(map[string]any) if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go new file mode 100644 index 00000000..d49f2eb3 --- /dev/null +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -0,0 +1,128 @@ +package service + +import ( + "testing" +) + +// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 +func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { + tests := []struct { + name string + tools any + expectedLen int + description string + }{ + { + name: "Standard tools", + tools: []any{ + map[string]any{ + "name": "get_weather", + "description": "Get weather info", + "input_schema": map[string]any{"type": "object"}, + }, + }, + expectedLen: 1, + description: "标准工具格式应该正常转换", + }, + { + name: "Custom type tool (MCP format)", + tools: []any{ + map[string]any{ + "type": "custom", + "name": "mcp_tool", + "custom": map[string]any{ + "description": "MCP tool description", + "input_schema": map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, + description: "Custom类型工具应该从custom字段读取", + }, + { + name: "Mixed standard and custom tools", + tools: []any{ + map[string]any{ + "name": "standard_tool", + "description": "Standard", + "input_schema": map[string]any{"type": "object"}, + }, + map[string]any{ + "type": "custom", + "name": "custom_tool", + "custom": map[string]any{ + "description": "Custom", + "input_schema": map[string]any{"type": "object"}, + }, + }, + }, + expectedLen: 1, + description: "混合工具应该都能正确转换", + }, + { + name: "Custom tool without custom field", + tools: []any{ + map[string]any{ + "type": "custom", + "name": "invalid_custom", + // 缺少 custom 字段 + }, + }, + expectedLen: 0, // 应该被跳过 + description: "缺少custom字段的custom工具应该被跳过", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertClaudeToolsToGeminiTools(tt.tools) + + if tt.expectedLen == 0 { + if result != nil { + t.Errorf("%s: expected nil result, got %v", tt.description, result) + } + return + } + + if result == nil { + t.Fatalf("%s: expected non-nil result", tt.description) + } + + if len(result) != 1 { + t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result)) + return + } + + toolDecl, ok := result[0].(map[string]any) + if !ok { + t.Fatalf("%s: result[0] is not map[string]any", tt.description) + } + + funcDecls, ok := toolDecl["functionDeclarations"].([]any) + if !ok { + t.Fatalf("%s: functionDeclarations is not []any", tt.description) + } + + toolsArr, _ := tt.tools.([]any) + expectedFuncCount := 0 + for _, tool := range toolsArr { + toolMap, _ := tool.(map[string]any) + if toolMap["name"] != "" { + // 检查是否为有效的custom工具 + if toolMap["type"] == "custom" { + if toolMap["custom"] != nil { + expectedFuncCount++ + } + } else { + expectedFuncCount++ + } + } + } + + if len(funcDecls) != expectedFuncCount { + t.Errorf("%s: expected %d function declarations, got %d", + tt.description, expectedFuncCount, len(funcDecls)) + } + }) + } +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index e4bda5f8..221bd0f2 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "regexp" "strconv" "strings" "time" @@ -163,6 +164,45 @@ type GeminiTokenInfo struct { Scope string `json:"scope,omitempty"` ProjectID string `json:"project_id,omitempty"` OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" + TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA +} + +// validateTierID validates tier_id format and length +func validateTierID(tierID string) error { + if tierID == "" { + return nil // Empty is allowed + } + if len(tierID) > 64 { + return fmt.Errorf("tier_id exceeds maximum length of 64 characters") + } + // Allow alphanumeric, underscore, hyphen, and slash (for tier paths) + if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) { + return fmt.Errorf("tier_id contains invalid characters") + } + return nil +} + +// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response +// Prioritizes IsDefault tier, falls back to first non-empty tier +func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string { + tierID := "LEGACY" + // First pass: look for default tier + for _, tier := range allowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + // Second pass: if still LEGACY, take first non-empty tier + if tierID == "LEGACY" { + for _, tier := range allowedTiers { + if strings.TrimSpace(tier.ID) != "" { + tierID = strings.TrimSpace(tier.ID) + break + } + } + } + return tierID } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { @@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 projectID := sessionProjectID + var tierID string // 对于 code_assist 模式,project_id 是必需的 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) if oauthType == "code_assist" { if projectID == "" { var err error - projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) @@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch ExpiresAt: expiresAt, Scope: tokenResp.Scope, ProjectID: projectID, + TierID: tierID, OAuthType: oauthType, }, nil } @@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A // For Code Assist, project_id is required. Auto-detect if missing. // For AI Studio OAuth, project_id is optional and should not block refresh. if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { - projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) + projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) if err != nil { return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) } @@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A return nil, fmt.Errorf("failed to auto-detect project_id: empty result") } tokenInfo.ProjectID = projectID + tokenInfo.TierID = tierID } return tokenInfo, nil @@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) if tokenInfo.ProjectID != "" { creds["project_id"] = tokenInfo.ProjectID } + if tokenInfo.TierID != "" { + // Validate tier_id before storing + if err := validateTierID(tokenInfo.TierID); err == nil { + creds["tier_id"] = tokenInfo.TierID + } + // Silently skip invalid tier_id (don't block account creation) + } if tokenInfo.OAuthType != "" { creds["oauth_type"] = tokenInfo.OAuthType } @@ -398,34 +448,26 @@ func (s *GeminiOAuthService) Stop() { s.sessionStore.Stop() } -func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) { if s.codeAssist == nil { - return "", errors.New("code assist client not configured") + return "", "", errors.New("code assist client not configured") } loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) + + // Extract tierID from response (works whether CloudAICompanionProject is set or not) + tierID := "LEGACY" + if loadResp != nil { + tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers) + } + + // If LoadCodeAssist returned a project, use it if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { - return strings.TrimSpace(loadResp.CloudAICompanionProject), nil + return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. - tierID := "LEGACY" - if loadResp != nil { - for _, tier := range loadResp.AllowedTiers { - if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" { - for _, tier := range loadResp.AllowedTiers { - if strings.TrimSpace(tier.ID) != "" { - tierID = strings.TrimSpace(tier.ID) - break - } - } - } - } + // (tierID already extracted above, reuse it) req := &geminicli.OnboardUserRequest{ TierID: tierID, @@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } - return "", err + return "", "", err } if resp.Done { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { switch v := resp.Response.CloudAICompanionProject.(type) { case string: - return strings.TrimSpace(v), nil + return strings.TrimSpace(v), tierID, nil case map[string]any: if id, ok := v["id"].(string); ok { - return strings.TrimSpace(id), nil + return strings.TrimSpace(id), tierID, nil } } } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } - return "", errors.New("onboardUser completed but no project_id returned") + return "", "", errors.New("onboardUser completed but no project_id returned") } time.Sleep(2 * time.Second) } fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - return strings.TrimSpace(fallback), nil + return strings.TrimSpace(fallback), tierID, nil } if loadErr != nil { - return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) + return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) } - return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) + return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) } type googleCloudProject struct { diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go index 2195ec55..5f369de5 100644 --- a/backend/internal/service/gemini_token_provider.go +++ b/backend/internal/service/gemini_token_provider.go @@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } - detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) + detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) if err != nil { log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) return accessToken, nil @@ -123,6 +123,9 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou account.Credentials = make(map[string]any) } account.Credentials["project_id"] = detected + if tierID != "" { + account.Credentials["tier_id"] = tierID + } _ = p.accountRepo.Update(ctx, account) } } diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 886c0a3a..403636e8 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index c074b13d..044f9ffc 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 7b0b80f5..b6324235 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -10,7 +10,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 0ffe991d..b5786ece 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,7 +9,7 @@ import ( "strconv" "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 09554c0f..d960c86f 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -6,7 +6,7 @@ import ( "log" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -490,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use } // CheckUsageLimits 检查使用限额(返回错误如果超限) +// 用于中间件的快速预检查,additionalCost 通常为 0 func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error { if !sub.CheckDailyLimit(group, additionalCost) { return ErrDailyLimitExceeded diff --git a/backend/internal/service/turnstile_service.go b/backend/internal/service/turnstile_service.go index cfb87c57..4afcc335 100644 --- a/backend/internal/service/turnstile_service.go +++ b/backend/internal/service/turnstile_service.go @@ -5,7 +5,7 @@ import ( "fmt" "log" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index f653ddfe..e1e97671 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index c17588c6..44a94d32 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 5565ab91..230d016f 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -11,7 +11,7 @@ import ( "strconv" "time" - "github.com/Wei-Shaw/sub2api/internal/infrastructure" + "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" _ "github.com/lib/pq" @@ -262,7 +262,7 @@ func initializeDatabase(cfg *SetupConfig) error { migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() - return infrastructure.ApplyMigrations(migrationCtx, db) + return repository.ApplyMigrations(migrationCtx, db) } func createAdminUser(cfg *SetupConfig) error { diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 5bd85d7d..5478d151 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -122,6 +122,21 @@ pricing: # Hash check interval in minutes hash_check_interval_minutes: 10 +# ============================================================================= +# Gateway (Optional) +# ============================================================================= +gateway: + # Wait time (in seconds) for upstream response headers (streaming body not affected) + response_header_timeout: 300 + # Log upstream error response body summary (safe/truncated; does not log request content) + log_upstream_error_body: false + # Max bytes to log from upstream error body + log_upstream_error_body_max_bytes: 2048 + # Auto inject anthropic-beta for API-key accounts when needed (default off) + inject_beta_for_apikey: false + # Allow failover on selected 400 errors (default off) + failover_on_400: false + # ============================================================================= # Gemini OAuth (Required for Gemini accounts) # ============================================================================= diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 6563ee0c..1770a985 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -952,6 +952,7 @@ "integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -1367,6 +1368,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -1443,6 +1445,7 @@ "resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "license": "MIT", + "peer": true, "dependencies": { "@kurkle/color": "^0.3.0" }, @@ -2040,6 +2043,7 @@ "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "dev": true, "license": "MIT", + "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -2348,6 +2352,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -2821,6 +2826,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -2854,6 +2860,7 @@ "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "devOptional": true, "license": "Apache-2.0", + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -2926,6 +2933,7 @@ "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -3097,6 +3105,7 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "license": "MIT", + "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.25", "@vue/compiler-sfc": "3.5.25", @@ -3190,6 +3199,7 @@ "integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@volar/typescript": "2.4.15", "@vue/language-core": "2.2.12" diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index c1ca08fa..914678a5 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -83,6 +83,14 @@ > + + + + {{ tierDisplay }} + @@ -140,4 +148,23 @@ const statusText = computed(() => { return props.account.status }) +// Computed: tier display +const tierDisplay = computed(() => { + const credentials = props.account.credentials as Record | undefined + const tierId = credentials?.tier_id + if (!tierId || tierId === 'unknown') return null + + const tierMap: Record = { + 'free': 'Free', + 'payg': 'Pay-as-you-go', + 'pay-as-you-go': 'Pay-as-you-go', + 'enterprise': 'Enterprise', + 'LEGACY': 'Legacy', + 'PRO': 'Pro', + 'ULTRA': 'Ultra' + } + + return tierMap[tierId] || tierId +}) +