fix(merge): 解决与 main 分支的配置冲突

- 合并 main 分支的上游错误日志配置
- 保留调度配置
- 合并 beta header 和 failover 配置
This commit is contained in:
IanShaw027
2026-01-01 04:33:12 +08:00
53 changed files with 951 additions and 417 deletions

View File

@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -31,7 +30,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
wire.Build( wire.Build(
// Infrastructure layer ProviderSets // Infrastructure layer ProviderSets
config.ProviderSet, config.ProviderSet,
infrastructure.ProviderSet,
// Business layer ProviderSets // Business layer ProviderSets
repository.ProviderSet, repository.ProviderSet,

View File

@@ -12,7 +12,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -35,18 +34,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
client, err := infrastructure.ProvideEnt(configConfig) client, err := repository.ProvideEnt(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db, err := infrastructure.ProvideSQLDB(client) db, err := repository.ProvideSQLDB(client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userRepository := repository.NewUserRepository(client, db) userRepository := repository.NewUserRepository(client, db)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := infrastructure.ProvideRedis(configConfig) redisClient := repository.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()

View File

@@ -121,6 +121,17 @@ type GatewayConfig struct {
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期 // 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
// 是否记录上游错误响应体摘要(避免输出请求内容)
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
// 上游错误响应体记录最大字节数(超过会截断)
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Scheduling: 账号调度相关配置 // Scheduling: 账号调度相关配置
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
} }
@@ -334,6 +345,10 @@ func setDefaults() {
// Gateway // Gateway
viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久 viper.SetDefault("gateway.response_header_timeout", 300) // 300秒(5分钟)等待上游响应头LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", false)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP 上游连接池配置(针对 5000+ 并发用户优化) // HTTP 上游连接池配置(针对 5000+ 并发用户优化)

View File

@@ -1,79 +0,0 @@
package infrastructure
import (
"database/sql"
"errors"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
entsql "entgo.io/ent/dialect/sql"
)
// ProviderSet 是基础设施层的 Wire 依赖提供者集合。
//
// Wire 是 Google 开发的编译时依赖注入工具。ProviderSet 将相关的依赖提供函数
// 组织在一起,便于在应用程序启动时自动组装依赖关系。
//
// 包含的提供者:
// - ProvideEnt: 提供 Ent ORM 客户端
// - ProvideSQLDB: 提供底层 SQL 数据库连接
// - ProvideRedis: 提供 Redis 客户端
var ProviderSet = wire.NewSet(
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
)
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -54,6 +54,9 @@ type CustomToolSpec struct {
InputSchema map[string]any `json:"input_schema"` InputSchema map[string]any `json:"input_schema"`
} }
// ClaudeCustomToolSpec 兼容旧命名MCP custom 工具规格)
type ClaudeCustomToolSpec = CustomToolSpec
// SystemBlock system prompt 数组形式的元素 // SystemBlock system prompt 数组形式的元素
type SystemBlock struct { type SystemBlock struct {
Type string `json:"type"` Type string `json:"type"`

View File

@@ -14,13 +14,16 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
// 用于存储 tool_use id -> name 映射 // 用于存储 tool_use id -> name 映射
toolIDToName := make(map[string]string) toolIDToName := make(map[string]string)
// 检测是否启用 thinking
isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 只有 Gemini 模型支持 dummy thought workaround // 只有 Gemini 模型支持 dummy thought workaround
// Claude 模型通过 Vertex/Google API 需要有效的 thought signatures // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures
allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") allowDummyThought := strings.HasPrefix(mappedModel, "gemini-")
// 检测是否启用 thinking
requestedThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
// 为避免 Claude 模型的 thought signature/消息块约束导致 400上游要求 thinking 块开头等),
// 非 Gemini 模型默认不启用 thinking除非未来支持完整签名链路
isThinkingEnabled := requestedThinkingEnabled && allowDummyThought
// 1. 构建 contents // 1. 构建 contents
contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) contents, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought)
if err != nil { if err != nil {
@@ -31,7 +34,15 @@ func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel st
systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model) systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model)
// 3. 构建 generationConfig // 3. 构建 generationConfig
generationConfig := buildGenerationConfig(claudeReq) reqForGen := claudeReq
if requestedThinkingEnabled && !allowDummyThought {
log.Printf("[Warning] Disabling thinking for non-Gemini model in antigravity transform: model=%s", mappedModel)
// shallow copy to avoid mutating caller's request
clone := *claudeReq
clone.Thinking = nil
reqForGen = &clone
}
generationConfig := buildGenerationConfig(reqForGen)
// 4. 构建 tools // 4. 构建 tools
tools := buildTools(claudeReq.Tools) tools := buildTools(claudeReq.Tools)
@@ -148,8 +159,9 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
if !hasThoughtPart && len(parts) > 0 { if !hasThoughtPart && len(parts) > 0 {
// 在开头添加 dummy thinking block // 在开头添加 dummy thinking block
parts = append([]GeminiPart{{ parts = append([]GeminiPart{{
Text: "Thinking...", Text: "Thinking...",
Thought: true, Thought: true,
ThoughtSignature: dummyThoughtSignature,
}}, parts...) }}, parts...)
} }
} }
@@ -171,6 +183,34 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT
// 参考: https://ai.google.dev/gemini-api/docs/thought-signatures // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures
const dummyThoughtSignature = "skip_thought_signature_validator" const dummyThoughtSignature = "skip_thought_signature_validator"
// isValidThoughtSignature 验证 thought signature 是否有效
// Claude API 要求 signature 必须是 base64 编码的字符串,长度至少 32 字节
func isValidThoughtSignature(signature string) bool {
// 空字符串无效
if signature == "" {
return false
}
// signature 应该是 base64 编码,长度至少 40 个字符(约 30 字节)
// 参考 Claude API 文档和实际观察到的有效 signature
if len(signature) < 40 {
log.Printf("[Debug] Signature too short: len=%d", len(signature))
return false
}
// 检查是否是有效的 base64 字符
// base64 字符集: A-Z, a-z, 0-9, +, /, =
for i, c := range signature {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') &&
(c < '0' || c > '9') && c != '+' && c != '/' && c != '=' {
log.Printf("[Debug] Invalid base64 character at position %d: %c (code=%d)", i, c, c)
return false
}
}
return true
}
// buildParts 构建消息的 parts // buildParts 构建消息的 parts
// allowDummyThought: 只有 Gemini 模型支持 dummy thought signature // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature
func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) { func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDummyThought bool) ([]GeminiPart, error) {
@@ -199,22 +239,30 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
} }
case "thinking": case "thinking":
part := GeminiPart{ if allowDummyThought {
Text: block.Thinking, // Gemini 模型可以使用 dummy signature
Thought: true, parts = append(parts, GeminiPart{
} Text: block.Thinking,
// 保留原有 signatureClaude 模型需要有效的 signature Thought: true,
if block.Signature != "" { ThoughtSignature: dummyThoughtSignature,
part.ThoughtSignature = block.Signature })
} else if !allowDummyThought {
// Claude 模型需要有效 signature跳过无 signature 的 thinking block
log.Printf("Warning: skipping thinking block without signature for Claude model")
continue continue
} else {
// Gemini 模型使用 dummy signature
part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part)
// Claude 模型:仅在提供有效 signature 时保留 thinking block否则跳过以避免上游校验失败。
signature := strings.TrimSpace(block.Signature)
if signature == "" || signature == dummyThoughtSignature {
log.Printf("[Warning] Skipping thinking block for Claude model (missing or dummy signature)")
continue
}
if !isValidThoughtSignature(signature) {
log.Printf("[Debug] Thinking signature may be invalid (passing through anyway): len=%d", len(signature))
}
parts = append(parts, GeminiPart{
Text: block.Thinking,
Thought: true,
ThoughtSignature: signature,
})
case "image": case "image":
if block.Source != nil && block.Source.Type == "base64" { if block.Source != nil && block.Source.Type == "base64" {
@@ -239,10 +287,9 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
ID: block.ID, ID: block.ID,
}, },
} }
// 保留原有 signature或对 Gemini 模型使用 dummy signature // 只有 Gemini 模型使用 dummy signature
if block.Signature != "" { // Claude 模型不设置 signature避免验证问题
part.ThoughtSignature = block.Signature if allowDummyThought {
} else if allowDummyThought {
part.ThoughtSignature = dummyThoughtSignature part.ThoughtSignature = dummyThoughtSignature
} }
parts = append(parts, part) parts = append(parts, part)
@@ -386,9 +433,9 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 普通工具 // 普通工具
var funcDecls []GeminiFunctionDecl var funcDecls []GeminiFunctionDecl
for _, tool := range tools { for i, tool := range tools {
// 跳过无效工具名称 // 跳过无效工具名称
if tool.Name == "" { if strings.TrimSpace(tool.Name) == "" {
log.Printf("Warning: skipping tool with empty name") log.Printf("Warning: skipping tool with empty name")
continue continue
} }
@@ -397,10 +444,18 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
var inputSchema map[string]any var inputSchema map[string]any
// 检查是否为 custom 类型工具 (MCP) // 检查是否为 custom 类型工具 (MCP)
if tool.Type == "custom" && tool.Custom != nil { if tool.Type == "custom" {
// Custom 格式: 从 custom 字段获取 description 和 input_schema if tool.Custom == nil || tool.Custom.InputSchema == nil {
log.Printf("[Warning] Skipping invalid custom tool '%s': missing custom spec or input_schema", tool.Name)
continue
}
description = tool.Custom.Description description = tool.Custom.Description
inputSchema = tool.Custom.InputSchema inputSchema = tool.Custom.InputSchema
// 调试日志:记录 custom 工具的 schema
if schemaJSON, err := json.Marshal(inputSchema); err == nil {
log.Printf("[Debug] Tool[%d] '%s' (custom) original schema: %s", i, tool.Name, string(schemaJSON))
}
} else { } else {
// 标准格式: 从顶层字段获取 // 标准格式: 从顶层字段获取
description = tool.Description description = tool.Description
@@ -409,7 +464,6 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
// 清理 JSON Schema // 清理 JSON Schema
params := cleanJSONSchema(inputSchema) params := cleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值 // 为 nil schema 提供默认值
if params == nil { if params == nil {
params = map[string]any{ params = map[string]any{
@@ -418,6 +472,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
} }
} }
// 调试日志:记录清理后的 schema
if paramsJSON, err := json.Marshal(params); err == nil {
log.Printf("[Debug] Tool[%d] '%s' cleaned schema: %s", i, tool.Name, string(paramsJSON))
}
funcDecls = append(funcDecls, GeminiFunctionDecl{ funcDecls = append(funcDecls, GeminiFunctionDecl{
Name: tool.Name, Name: tool.Name,
Description: description, Description: description,
@@ -479,31 +538,64 @@ func cleanJSONSchema(schema map[string]any) map[string]any {
} }
// excludedSchemaKeys 不支持的 schema 字段 // excludedSchemaKeys 不支持的 schema 字段
// 基于 Claude API (Vertex AI) 的实际支持情况
// 支持: type, description, enum, properties, required, additionalProperties, items
// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
var excludedSchemaKeys = map[string]bool{ var excludedSchemaKeys = map[string]bool{
"$schema": true, // 元 schema 字段
"$id": true, "$schema": true,
"$ref": true, "$id": true,
"additionalProperties": true, "$ref": true,
"minLength": true,
"maxLength": true, // 字符串验证Gemini 不支持)
"minItems": true, "minLength": true,
"maxItems": true, "maxLength": true,
"uniqueItems": true, "pattern": true,
"minimum": true,
"maximum": true, // 数字验证Claude API 通过 Vertex AI 不支持这些字段)
"exclusiveMinimum": true, "minimum": true,
"exclusiveMaximum": true, "maximum": true,
"pattern": true, "exclusiveMinimum": true,
"format": true, "exclusiveMaximum": true,
"default": true, "multipleOf": true,
"strict": true,
"const": true, // 数组验证Claude API 通过 Vertex AI 不支持这些字段)
"examples": true, "uniqueItems": true,
"deprecated": true, "minItems": true,
"readOnly": true, "maxItems": true,
"writeOnly": true,
"contentMediaType": true, // 组合 schemaGemini 不支持)
"contentEncoding": true, "oneOf": true,
"anyOf": true,
"allOf": true,
"not": true,
"if": true,
"then": true,
"else": true,
"$defs": true,
"definitions": true,
// 对象验证(仅保留 properties/required/additionalProperties
"minProperties": true,
"maxProperties": true,
"patternProperties": true,
"propertyNames": true,
"dependencies": true,
"dependentSchemas": true,
"dependentRequired": true,
// 其他不支持的字段
"default": true,
"const": true,
"examples": true,
"deprecated": true,
"readOnly": true,
"writeOnly": true,
"contentMediaType": true,
"contentEncoding": true,
// Claude 特有字段
"strict": true,
} }
// cleanSchemaValue 递归清理 schema 值 // cleanSchemaValue 递归清理 schema 值
@@ -523,6 +615,31 @@ func cleanSchemaValue(value any) any {
continue continue
} }
// 特殊处理 format 字段:只保留 Gemini 支持的 format 值
if k == "format" {
if formatStr, ok := val.(string); ok {
// Gemini 只支持 date-time, date, time
if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
result[k] = val
}
// 其他 format 值直接跳过
}
continue
}
// 特殊处理 additionalPropertiesClaude API 只支持布尔值,不支持 schema 对象
if k == "additionalProperties" {
if boolVal, ok := val.(bool); ok {
result[k] = boolVal
log.Printf("[Debug] additionalProperties is bool: %v", boolVal)
} else {
// 如果是 schema 对象,转换为 false更安全的默认值
result[k] = false
log.Printf("[Debug] additionalProperties is not bool (type: %T), converting to false", val)
}
continue
}
// 递归清理所有值 // 递归清理所有值
result[k] = cleanSchemaValue(val) result[k] = cleanSchemaValue(val)
} }

View File

@@ -0,0 +1,179 @@
package antigravity
import (
"encoding/json"
"testing"
)
// TestBuildParts_ThinkingBlockWithoutSignature 测试thinking block无signature时的处理
func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) {
tests := []struct {
name string
content string
allowDummyThought bool
expectedParts int
description string
}{
{
name: "Claude model - skip thinking block without signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 2, // 只有两个text block
description: "Claude模型应该跳过无signature的thinking block",
},
{
name: "Claude model - keep thinking block with signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": "valid_sig"},
{"type": "text", "text": "World"}
]`,
allowDummyThought: false,
expectedParts: 3, // 三个block都保留
description: "Claude模型应该保留有signature的thinking block",
},
{
name: "Gemini model - use dummy signature",
content: `[
{"type": "text", "text": "Hello"},
{"type": "thinking", "thinking": "Let me think...", "signature": ""},
{"type": "text", "text": "World"}
]`,
allowDummyThought: true,
expectedParts: 3, // 三个block都保留thinking使用dummy signature
description: "Gemini模型应该为无signature的thinking block使用dummy signature",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, err := buildParts(json.RawMessage(tt.content), toolIDToName, tt.allowDummyThought)
if err != nil {
t.Fatalf("buildParts() error = %v", err)
}
if len(parts) != tt.expectedParts {
t.Errorf("%s: got %d parts, want %d parts", tt.description, len(parts), tt.expectedParts)
}
})
}
}
// TestBuildTools_CustomTypeTools 测试custom类型工具转换
func TestBuildTools_CustomTypeTools(t *testing.T) {
tests := []struct {
name string
tools []ClaudeTool
expectedLen int
description string
}{
{
name: "Standard tool format",
tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"location": map[string]any{"type": "string"},
},
},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []ClaudeTool{
{
Type: "custom",
Name: "mcp_tool",
Custom: &ClaudeCustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
"properties": map[string]any{
"param": map[string]any{"type": "string"},
},
},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从Custom字段读取description和input_schema",
},
{
name: "Mixed standard and custom tools",
tools: []ClaudeTool{
{
Name: "standard_tool",
Description: "Standard tool",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "custom",
Name: "custom_tool",
Custom: &ClaudeCustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
},
},
expectedLen: 1, // 返回一个GeminiToolDeclaration包含2个function declarations
description: "混合标准和custom工具应该都能正确转换",
},
{
name: "Invalid custom tool - nil Custom field",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
// Custom 为 nil
},
},
expectedLen: 0, // 应该被跳过
description: "Custom字段为nil的custom工具应该被跳过",
},
{
name: "Invalid custom tool - nil InputSchema",
tools: []ClaudeTool{
{
Type: "custom",
Name: "invalid_custom",
Custom: &ClaudeCustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},
},
},
expectedLen: 0, // 应该被跳过
description: "InputSchema为nil的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildTools(tt.tools)
if len(result) != tt.expectedLen {
t.Errorf("%s: got %d tool declarations, want %d", tt.description, len(result), tt.expectedLen)
}
// 验证function declarations存在
if len(result) > 0 && result[0].FunctionDeclarations != nil {
if len(result[0].FunctionDeclarations) != len(tt.tools) {
t.Errorf("%s: got %d function declarations, want %d",
tt.description, len(result[0].FunctionDeclarations), len(tt.tools))
}
}
})
}
}

View File

@@ -16,6 +16,12 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// Claude Code 客户端默认请求头 // Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{ var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)", "User-Agent": "claude-cli/2.0.62 (external, cli)",

View File

@@ -4,7 +4,7 @@ import (
"math" "math"
"net/http" "net/http"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -9,7 +9,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" errors2 "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -82,7 +82,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "application_error", name: "application_error",
err: infraerrors.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}), err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusForbidden, wantHTTPCode: http.StatusForbidden,
wantBody: Response{ wantBody: Response{
@@ -94,7 +94,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "bad_request_error", name: "bad_request_error",
err: infraerrors.BadRequest("INVALID_REQUEST", "invalid request"), err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusBadRequest, wantHTTPCode: http.StatusBadRequest,
wantBody: Response{ wantBody: Response{
@@ -105,7 +105,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "unauthorized_error", name: "unauthorized_error",
err: infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized"), err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusUnauthorized, wantHTTPCode: http.StatusUnauthorized,
wantBody: Response{ wantBody: Response{
@@ -116,7 +116,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "not_found_error", name: "not_found_error",
err: infraerrors.NotFound("NOT_FOUND", "not found"), err: errors2.NotFound("NOT_FOUND", "not found"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusNotFound, wantHTTPCode: http.StatusNotFound,
wantBody: Response{ wantBody: Response{
@@ -127,7 +127,7 @@ func TestErrorFrom(t *testing.T) {
}, },
{ {
name: "conflict_error", name: "conflict_error",
err: infraerrors.Conflict("CONFLICT", "conflict"), err: errors2.Conflict("CONFLICT", "conflict"),
wantWritten: true, wantWritten: true,
wantHTTPCode: http.StatusConflict, wantHTTPCode: http.StatusConflict,
wantBody: Response{ wantBody: Response{
@@ -143,7 +143,7 @@ func TestErrorFrom(t *testing.T) {
wantHTTPCode: http.StatusInternalServerError, wantHTTPCode: http.StatusInternalServerError,
wantBody: Response{ wantBody: Response{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage, Message: errors2.UnknownMessage,
}, },
}, },
} }

View File

@@ -1,4 +1,4 @@
package infrastructure package repository
import ( import (
"database/sql" "database/sql"

View File

@@ -1,4 +1,4 @@
package infrastructure package repository
import ( import (
"database/sql" "database/sql"

View File

@@ -1,6 +1,6 @@
// Package infrastructure 提供应用程序的基础设施层组件。 // Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。 // 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package infrastructure package repository
import ( import (
"context" "context"

View File

@@ -7,7 +7,7 @@ import (
"strings" "strings"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq" "github.com/lib/pq"
) )

View File

@@ -17,7 +17,6 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime" _ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@@ -97,7 +96,7 @@ func TestMain(m *testing.M) {
log.Printf("failed to open sql db: %v", err) log.Printf("failed to open sql db: %v", err)
os.Exit(1) os.Exit(1)
} }
if err := infrastructure.ApplyMigrations(ctx, integrationDB); err != nil { if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err) log.Printf("failed to apply db migrations: %v", err)
os.Exit(1) os.Exit(1)
} }

View File

@@ -1,4 +1,4 @@
package infrastructure package repository
import ( import (
"context" "context"

View File

@@ -7,7 +7,6 @@ import (
"database/sql" "database/sql"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/infrastructure"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -15,7 +14,7 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t) tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows). // Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, infrastructure.ApplyMigrations(context.Background(), integrationDB)) require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set. // schema_migrations should have at least the current migration set.
var applied int var applied int

View File

@@ -1,4 +1,4 @@
package infrastructure package repository
import ( import (
"time" "time"

View File

@@ -1,4 +1,4 @@
package infrastructure package repository
import ( import (
"testing" "testing"

View File

@@ -291,13 +291,11 @@ func (r *userSubscriptionRepository) ResetMonthlyUsage(ctx context.Context, id i
return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil) return translatePersistenceError(err, service.ErrSubscriptionNotFound, nil)
} }
// IncrementUsage 原子性地累加用量并校验限额 // IncrementUsage 原子性地累加订阅用量。
// 使用单条 SQL 语句同时检查 Group 的限额,如果任一限额即将超出则拒绝更新。 // 限额检查已在请求前由 BillingCacheService.CheckBillingEligibility 完成,
// 当更新失败时,会执行额外查询确定具体超出的限额类型 // 此处仅负责记录实际消费,确保消费数据的完整性
func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
// 使用 JOIN 的原子更新:只有当所有限额条件满足时才执行累加 const updateSQL = `
// NULL 限额表示无限制
const atomicUpdateSQL = `
UPDATE user_subscriptions us UPDATE user_subscriptions us
SET SET
daily_usage_usd = us.daily_usage_usd + $1, daily_usage_usd = us.daily_usage_usd + $1,
@@ -309,13 +307,10 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
AND us.deleted_at IS NULL AND us.deleted_at IS NULL
AND us.group_id = g.id AND us.group_id = g.id
AND g.deleted_at IS NULL AND g.deleted_at IS NULL
AND (g.daily_limit_usd IS NULL OR us.daily_usage_usd + $1 <= g.daily_limit_usd)
AND (g.weekly_limit_usd IS NULL OR us.weekly_usage_usd + $1 <= g.weekly_limit_usd)
AND (g.monthly_limit_usd IS NULL OR us.monthly_usage_usd + $1 <= g.monthly_limit_usd)
` `
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(ctx, atomicUpdateSQL, costUSD, id) result, err := client.ExecContext(ctx, updateSQL, costUSD, id)
if err != nil { if err != nil {
return err return err
} }
@@ -326,64 +321,11 @@ func (r *userSubscriptionRepository) IncrementUsage(ctx context.Context, id int6
} }
if affected > 0 { if affected > 0 {
return nil // 更新成功 return nil
} }
// affected == 0可能是订阅不存在、分组已删除、或限额超出 // affected == 0订阅不存在或已删除
// 执行额外查询确定具体原因 return service.ErrSubscriptionNotFound
return r.checkIncrementFailureReason(ctx, id, costUSD)
}
// checkIncrementFailureReason 查询更新失败的具体原因
func (r *userSubscriptionRepository) checkIncrementFailureReason(ctx context.Context, id int64, costUSD float64) error {
const checkSQL = `
SELECT
CASE WHEN us.deleted_at IS NOT NULL THEN 'subscription_deleted'
WHEN g.id IS NULL THEN 'subscription_not_found'
WHEN g.deleted_at IS NOT NULL THEN 'group_deleted'
WHEN g.daily_limit_usd IS NOT NULL AND us.daily_usage_usd + $1 > g.daily_limit_usd THEN 'daily_exceeded'
WHEN g.weekly_limit_usd IS NOT NULL AND us.weekly_usage_usd + $1 > g.weekly_limit_usd THEN 'weekly_exceeded'
WHEN g.monthly_limit_usd IS NOT NULL AND us.monthly_usage_usd + $1 > g.monthly_limit_usd THEN 'monthly_exceeded'
ELSE 'unknown'
END AS reason
FROM user_subscriptions us
LEFT JOIN groups g ON us.group_id = g.id
WHERE us.id = $2
`
client := clientFromContext(ctx, r.client)
rows, err := client.QueryContext(ctx, checkSQL, costUSD, id)
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
if !rows.Next() {
return service.ErrSubscriptionNotFound
}
var reason string
if err := rows.Scan(&reason); err != nil {
return err
}
if err := rows.Err(); err != nil {
return err
}
switch reason {
case "subscription_not_found", "subscription_deleted", "group_deleted":
return service.ErrSubscriptionNotFound
case "daily_exceeded":
return service.ErrDailyLimitExceeded
case "weekly_exceeded":
return service.ErrWeeklyLimitExceeded
case "monthly_exceeded":
return service.ErrMonthlyLimitExceeded
default:
// unknown 情况理论上不应发生,但作为兜底返回
return service.ErrSubscriptionNotFound
}
} }
func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { func (r *userSubscriptionRepository) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {

View File

@@ -633,112 +633,7 @@ func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_Ba
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired") s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
} }
// --- 限额检查与软删除过滤测试 --- // --- 软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) mustCreateGroupWithLimits(name string, daily, weekly, monthly *float64) *service.Group {
s.T().Helper()
create := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeSubscription)
if daily != nil {
create.SetDailyLimitUsd(*daily)
}
if weekly != nil {
create.SetWeeklyLimitUsd(*weekly)
}
if monthly != nil {
create.SetMonthlyLimitUsd(*monthly)
}
g, err := create.Save(s.ctx)
s.Require().NoError(err, "create group with limits")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_DailyLimitExceeded() {
user := s.mustCreateUser("dailylimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-dailylimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 先增加 9.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 9.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 2.0,会超过 10.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 2.0)
s.Require().Error(err, "should fail when daily limit exceeded")
s.Require().ErrorIs(err, service.ErrDailyLimitExceeded)
// 验证用量没有变化
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(9.0, got.DailyUsageUSD, 1e-6, "usage should not change after failed increment")
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_WeeklyLimitExceeded() {
user := s.mustCreateUser("weeklylimit@test.com", service.RoleUser)
weeklyLimit := 50.0
group := s.mustCreateGroupWithLimits("g-weeklylimit", nil, &weeklyLimit, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 45.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 45.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 10.0,会超过 50.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().Error(err, "should fail when weekly limit exceeded")
s.Require().ErrorIs(err, service.ErrWeeklyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_MonthlyLimitExceeded() {
user := s.mustCreateUser("monthlylimit@test.com", service.RoleUser)
monthlyLimit := 100.0
group := s.mustCreateGroupWithLimits("g-monthlylimit", nil, nil, &monthlyLimit)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 增加 90.0,应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 90.0)
s.Require().NoError(err, "first increment should succeed")
// 再增加 20.0,会超过 100.0 限额,应该失败
err = s.repo.IncrementUsage(s.ctx, sub.ID, 20.0)
s.Require().Error(err, "should fail when monthly limit exceeded")
s.Require().ErrorIs(err, service.ErrMonthlyLimitExceeded)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NoLimits() {
user := s.mustCreateUser("nolimits@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-nolimits", nil, nil, nil) // 无限额
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 应该可以增加任意金额
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1000000.0)
s.Require().NoError(err, "should succeed without limits")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(1000000.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_AtExactLimit() {
user := s.mustCreateUser("exactlimit@test.com", service.RoleUser)
dailyLimit := 10.0
group := s.mustCreateGroupWithLimits("g-exactlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 正好达到限额应该成功
err := s.repo.IncrementUsage(s.ctx, sub.ID, 10.0)
s.Require().NoError(err, "should succeed at exact limit")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(10.0, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser) user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
@@ -779,7 +674,7 @@ func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() { func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser) user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroupWithLimits("g-concurrent", nil, nil, nil) // 无限额 group := s.mustCreateGroup("g-concurrent")
sub := s.mustCreateSubscription(user.ID, group.ID, nil) sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10 const numGoroutines = 10
@@ -808,34 +703,6 @@ func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated") s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
} }
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_ConcurrentWithLimit() {
user := s.mustCreateUser("concurrentlimit@test.com", service.RoleUser)
dailyLimit := 5.0
group := s.mustCreateGroupWithLimits("g-concurrentlimit", &dailyLimit, nil, nil)
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 注意:事务内的操作是串行的,所以这里改为顺序执行以验证限额逻辑
// 尝试增加 10 次,每次 1.0,但限额只有 5.0
const numAttempts = 10
const incrementPerAttempt = 1.0
successCount := 0
for i := 0; i < numAttempts; i++ {
err := s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerAttempt)
if err == nil {
successCount++
}
}
// 验证:应该有 5 次成功不超过限额5 次失败(超出限额)
s.Require().Equal(5, successCount, "exactly 5 increments should succeed (limit=5, increment=1)")
// 验证最终用量等于限额
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(dailyLimit, got.DailyUsageUSD, 1e-6, "daily usage should equal limit")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() { func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T()) baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background()) tx, err := baseClient.Tx(context.Background())

View File

@@ -1,6 +1,11 @@
package repository package repository
import ( import (
"database/sql"
"errors"
entsql "entgo.io/ent/dialect/sql"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
@@ -54,4 +59,58 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthClient, NewOpenAIOAuthClient,
NewGeminiOAuthClient, NewGeminiOAuthClient,
NewGeminiCliCodeAssistClient, NewGeminiCliCodeAssistClient,
ProvideEnt,
ProvideSQLDB,
ProvideRedis,
) )
// ProvideEnt 为依赖注入提供 Ent 客户端。
//
// 该函数是 InitEnt 的包装器,符合 Wire 的依赖提供函数签名要求。
// Wire 会在编译时分析依赖关系,自动生成初始化代码。
//
// 依赖config.Config
// 提供:*ent.Client
func ProvideEnt(cfg *config.Config) (*ent.Client, error) {
client, _, err := InitEnt(cfg)
return client, err
}
// ProvideSQLDB 从 Ent 客户端提取底层的 *sql.DB 连接。
//
// 某些 Repository 需要直接执行原生 SQL如复杂的批量更新、聚合查询
// 此时需要访问底层的 sql.DB 而不是通过 Ent ORM。
//
// 设计说明:
// - Ent 底层使用 sql.DB通过 Driver 接口可以访问
// - 这种设计允许在同一事务中混用 Ent 和原生 SQL
//
// 依赖:*ent.Client
// 提供:*sql.DB
func ProvideSQLDB(client *ent.Client) (*sql.DB, error) {
if client == nil {
return nil, errors.New("nil ent client")
}
// 从 Ent 客户端获取底层驱动
drv, ok := client.Driver().(*entsql.Driver)
if !ok {
return nil, errors.New("ent driver does not expose *sql.DB")
}
// 返回驱动持有的 sql.DB 实例
return drv.DB(), nil
}
// ProvideRedis 为依赖注入提供 Redis 客户端。
//
// Redis 用于:
// - 分布式锁(如并发控制)
// - 缓存如用户会话、API 响应缓存)
// - 速率限制
// - 实时统计数据
//
// 依赖config.Config
// 提供:*redis.Client
func ProvideRedis(cfg *config.Config) *redis.Client {
return InitRedis(cfg)
}

View File

@@ -7,7 +7,7 @@ import (
"os" "os"
"strings" "strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )

View File

@@ -8,7 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"testing" "testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )

View File

@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard subscriptionType = SubscriptionTypeStandard
} }
// 限额字段0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive: input.IsExclusive, IsExclusive: input.IsExclusive,
Status: StatusActive, Status: StatusActive,
SubscriptionType: subscriptionType, SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: input.WeeklyLimitUSD, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: input.MonthlyLimitUSD, MonthlyLimitUSD: monthlyLimit,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil return group, nil
} }
// normalizeLimit 将 0 或负数转换为 nil表示无限制
func normalizeLimit(limit *float64) *float64 {
if limit == nil || *limit <= 0 {
return nil
}
return limit
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" { if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType group.SubscriptionType = input.SubscriptionType
} }
// 限额字段支持设置为nil清除限额或具体值 // 限额字段0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil { if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
} }
if input.WeeklyLimitUSD != nil { if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
} }
if input.MonthlyLimitUSD != nil { if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
} }
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {

View File

@@ -358,6 +358,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err) return nil, fmt.Errorf("transform request: %w", err)
} }
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action // 构建上游 action
action := "generateContent" action := "generateContent"
if claudeReq.Stream { if claudeReq.Stream {

View File

@@ -2,6 +2,7 @@ package service
import ( import (
"context" "context"
"fmt"
"time" "time"
) )
@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
} }
// NeedsRefresh 检查账户是否需要刷新 // NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置 // Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool { func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) { if !r.CanRefresh(account) {
return false return false
@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati
if expiresAt == nil { if expiresAt == nil {
return false return false
} }
return time.Until(*expiresAt) < antigravityRefreshWindow timeUntilExpiry := time.Until(*expiresAt)
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
}
return needsRefresh
} }
// Refresh 执行 token 刷新 // Refresh 执行 token 刷新

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
) )

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"

View File

@@ -9,7 +9,7 @@ import (
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
// 错误定义 // 错误定义

View File

@@ -10,7 +10,7 @@ import (
"strconv" "strconv"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
var ( var (

View File

@@ -20,6 +20,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -1061,6 +1062,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误) // 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account) return s.handleErrorResponse(ctx, resp, c, account)
} }
@@ -1163,6 +1188,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理 // 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil
@@ -1215,6 +1247,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.DefaultBetaHeader return claude.DefaultBetaHeader
} }
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
@@ -1227,6 +1336,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode { switch resp.StatusCode {
case 400: case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body) c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401: case 401:
@@ -1706,6 +1825,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态429/529等 // 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应 // 返回简化的错误响应
errMsg := "Upstream request failed" errMsg := "Upstream request failed"
switch resp.StatusCode { switch resp.StatusCode {
@@ -1786,6 +1917,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header // OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" { if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
} }
return req, nil return req, nil

View File

@@ -2278,11 +2278,13 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
"properties": map[string]any{}, "properties": map[string]any{},
} }
} }
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{ funcDecls = append(funcDecls, map[string]any{
"name": name, "name": name,
"description": desc, "description": desc,
"parameters": params, "parameters": cleanedParams,
}) })
} }
@@ -2296,6 +2298,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
} }
} }
// cleanToolSchema 清理工具的 JSON Schema移除 Gemini 不支持的字段
func cleanToolSchema(schema any) any {
if schema == nil {
return nil
}
switch v := schema.(type) {
case map[string]any:
cleaned := make(map[string]any)
for key, value := range v {
// 跳过不支持的字段
if key == "$schema" || key == "$id" || key == "$ref" ||
key == "additionalProperties" || key == "minLength" ||
key == "maxLength" || key == "minItems" || key == "maxItems" {
continue
}
// 递归清理嵌套对象
cleaned[key] = cleanToolSchema(value)
}
// 规范化 type 字段为大写
if typeVal, ok := cleaned["type"].(string); ok {
cleaned["type"] = strings.ToUpper(typeVal)
}
return cleaned
case []any:
cleaned := make([]any, len(v))
for i, item := range v {
cleaned[i] = cleanToolSchema(item)
}
return cleaned
default:
return v
}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any { func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any) out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {

View File

@@ -0,0 +1,128 @@
package service
import (
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
name string
tools any
expectedLen int
description string
}{
{
name: "Standard tools",
tools: []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
},
expectedLen: 1,
description: "标准工具格式应该正常转换",
},
{
name: "Custom type tool (MCP format)",
tools: []any{
map[string]any{
"type": "custom",
"name": "mcp_tool",
"custom": map[string]any{
"description": "MCP tool description",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "Custom类型工具应该从custom字段读取",
},
{
name: "Mixed standard and custom tools",
tools: []any{
map[string]any{
"name": "standard_tool",
"description": "Standard",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "custom",
"name": "custom_tool",
"custom": map[string]any{
"description": "Custom",
"input_schema": map[string]any{"type": "object"},
},
},
},
expectedLen: 1,
description: "混合工具应该都能正确转换",
},
{
name: "Custom tool without custom field",
tools: []any{
map[string]any{
"type": "custom",
"name": "invalid_custom",
// 缺少 custom 字段
},
},
expectedLen: 0, // 应该被跳过
description: "缺少custom字段的custom工具应该被跳过",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertClaudeToolsToGeminiTools(tt.tools)
if tt.expectedLen == 0 {
if result != nil {
t.Errorf("%s: expected nil result, got %v", tt.description, result)
}
return
}
if result == nil {
t.Fatalf("%s: expected non-nil result", tt.description)
}
if len(result) != 1 {
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
return
}
toolDecl, ok := result[0].(map[string]any)
if !ok {
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
}
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
if !ok {
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
}
toolsArr, _ := tt.tools.([]any)
expectedFuncCount := 0
for _, tool := range toolsArr {
toolMap, _ := tool.(map[string]any)
if toolMap["name"] != "" {
// 检查是否为有效的custom工具
if toolMap["type"] == "custom" {
if toolMap["custom"] != nil {
expectedFuncCount++
}
} else {
expectedFuncCount++
}
}
}
if len(funcDecls) != expectedFuncCount {
t.Errorf("%s: expected %d function declarations, got %d",
tt.description, expectedFuncCount, len(funcDecls))
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -163,6 +164,45 @@ type GeminiTokenInfo struct {
Scope string `json:"scope,omitempty"` Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"` ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio" OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
}
// validateTierID validates tier_id format and length
func validateTierID(tierID string) error {
if tierID == "" {
return nil // Empty is allowed
}
if len(tierID) > 64 {
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
}
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
return fmt.Errorf("tier_id contains invalid characters")
}
return nil
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
tierID := "LEGACY"
// First pass: look for default tier
for _, tier := range allowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
// Second pass: if still LEGACY, take first non-empty tier
if tierID == "LEGACY" {
for _, tier := range allowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
return tierID
} }
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
@@ -223,13 +263,14 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300 expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
projectID := sessionProjectID projectID := sessionProjectID
var tierID string
// 对于 code_assist 模式project_id 是必需的 // 对于 code_assist 模式project_id 是必需的
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API // 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" { if oauthType == "code_assist" {
if projectID == "" { if projectID == "" {
var err error var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil { if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id // 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
@@ -248,6 +289,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
Scope: tokenResp.Scope, Scope: tokenResp.Scope,
ProjectID: projectID, ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType, OAuthType: oauthType,
}, nil }, nil
} }
@@ -357,7 +399,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
// For Code Assist, project_id is required. Auto-detect if missing. // For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh. // For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" { if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL) projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err) return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
} }
@@ -366,6 +408,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
return nil, fmt.Errorf("failed to auto-detect project_id: empty result") return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
} }
tokenInfo.ProjectID = projectID tokenInfo.ProjectID = projectID
tokenInfo.TierID = tierID
} }
return tokenInfo, nil return tokenInfo, nil
@@ -388,6 +431,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" { if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID creds["project_id"] = tokenInfo.ProjectID
} }
if tokenInfo.TierID != "" {
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if tokenInfo.OAuthType != "" { if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType creds["oauth_type"] = tokenInfo.OAuthType
} }
@@ -398,34 +448,26 @@ func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop() s.sessionStore.Stop()
} }
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
if s.codeAssist == nil { if s.codeAssist == nil {
return "", errors.New("code assist client not configured") return "", "", errors.New("code assist client not configured")
} }
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
// If LoadCodeAssist returned a project, use it
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
} }
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID. // Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
tierID := "LEGACY" // (tierID already extracted above, reuse it)
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
}
req := &geminicli.OnboardUserRequest{ req := &geminicli.OnboardUserRequest{
TierID: tierID, TierID: tierID,
@@ -443,39 +485,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects. // If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
return "", err return "", "", err
} }
if resp.Done { if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil { if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) { switch v := resp.Response.CloudAICompanionProject.(type) {
case string: case string:
return strings.TrimSpace(v), nil return strings.TrimSpace(v), tierID, nil
case map[string]any: case map[string]any:
if id, ok := v["id"].(string); ok { if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil return strings.TrimSpace(id), tierID, nil
} }
} }
} }
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
return "", errors.New("onboardUser completed but no project_id returned") return "", "", errors.New("onboardUser completed but no project_id returned")
} }
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" { if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil return strings.TrimSpace(fallback), tierID, nil
} }
if loadErr != nil { if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts) return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
} }
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
} }
type googleCloudProject struct { type googleCloudProject struct {

View File

@@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
} }
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL) detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil { if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err) log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil return accessToken, nil
@@ -123,6 +123,9 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
account.Credentials = make(map[string]any) account.Credentials = make(map[string]any)
} }
account.Credentials["project_id"] = detected account.Credentials["project_id"] = detected
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account) _ = p.accountRepo.Update(ctx, account)
} }
} }

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )

View File

@@ -10,7 +10,7 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )

View File

@@ -9,7 +9,7 @@ import (
"strconv" "strconv"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
var ( var (

View File

@@ -6,7 +6,7 @@ import (
"log" "log"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )
@@ -490,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
} }
// CheckUsageLimits 检查使用限额(返回错误如果超限) // CheckUsageLimits 检查使用限额(返回错误如果超限)
// 用于中间件的快速预检查additionalCost 通常为 0
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error { func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) { if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded return ErrDailyLimitExceeded

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
) )
var ( var (

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
) )

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
) )

View File

@@ -11,7 +11,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/infrastructure" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@@ -262,7 +262,7 @@ func initializeDatabase(cfg *SetupConfig) error {
migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) migrationCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel() defer cancel()
return infrastructure.ApplyMigrations(migrationCtx, db) return repository.ApplyMigrations(migrationCtx, db)
} }
func createAdminUser(cfg *SetupConfig) error { func createAdminUser(cfg *SetupConfig) error {

View File

@@ -122,6 +122,21 @@ pricing:
# Hash check interval in minutes # Hash check interval in minutes
hash_check_interval_minutes: 10 hash_check_interval_minutes: 10
# =============================================================================
# Gateway (Optional)
# =============================================================================
gateway:
# Wait time (in seconds) for upstream response headers (streaming body not affected)
response_header_timeout: 300
# Log upstream error response body summary (safe/truncated; does not log request content)
log_upstream_error_body: false
# Max bytes to log from upstream error body
log_upstream_error_body_max_bytes: 2048
# Auto inject anthropic-beta for API-key accounts when needed (default off)
inject_beta_for_apikey: false
# Allow failover on selected 400 errors (default off)
failover_on_400: false
# ============================================================================= # =============================================================================
# Gemini OAuth (Required for Gemini accounts) # Gemini OAuth (Required for Gemini accounts)
# ============================================================================= # =============================================================================

View File

@@ -952,6 +952,7 @@
"integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==", "integrity": "sha512-N2clP5pJhB2YnZJ3PIHFk5RkygRX5WO/5f0WC08tp0wd+sv0rsJk3MqWn3CbNmT2J505a5336jaQj4ph1AdMug==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"undici-types": "~6.21.0" "undici-types": "~6.21.0"
} }
@@ -1367,6 +1368,7 @@
} }
], ],
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"baseline-browser-mapping": "^2.9.0", "baseline-browser-mapping": "^2.9.0",
"caniuse-lite": "^1.0.30001759", "caniuse-lite": "^1.0.30001759",
@@ -1443,6 +1445,7 @@
"resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz", "resolved": "https://registry.npmmirror.com/chart.js/-/chart.js-4.5.1.tgz",
"integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==", "integrity": "sha512-GIjfiT9dbmHRiYi6Nl2yFCq7kkwdkp1W/lp2J99rX0yo9tgJGn3lKQATztIjb5tVtevcBtIdICNWqlq5+E8/Pw==",
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"@kurkle/color": "^0.3.0" "@kurkle/color": "^0.3.0"
}, },
@@ -2040,6 +2043,7 @@
"integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"bin": { "bin": {
"jiti": "bin/jiti.js" "jiti": "bin/jiti.js"
} }
@@ -2348,6 +2352,7 @@
} }
], ],
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"nanoid": "^3.3.11", "nanoid": "^3.3.11",
"picocolors": "^1.1.1", "picocolors": "^1.1.1",
@@ -2821,6 +2826,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"engines": { "engines": {
"node": ">=12" "node": ">=12"
}, },
@@ -2854,6 +2860,7 @@
"integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==",
"devOptional": true, "devOptional": true,
"license": "Apache-2.0", "license": "Apache-2.0",
"peer": true,
"bin": { "bin": {
"tsc": "bin/tsc", "tsc": "bin/tsc",
"tsserver": "bin/tsserver" "tsserver": "bin/tsserver"
@@ -2926,6 +2933,7 @@
"integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==", "integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"esbuild": "^0.21.3", "esbuild": "^0.21.3",
"postcss": "^8.4.43", "postcss": "^8.4.43",
@@ -3097,6 +3105,7 @@
"resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz", "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.25.tgz",
"integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==",
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"@vue/compiler-dom": "3.5.25", "@vue/compiler-dom": "3.5.25",
"@vue/compiler-sfc": "3.5.25", "@vue/compiler-sfc": "3.5.25",
@@ -3190,6 +3199,7 @@
"integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==", "integrity": "sha512-P7OP77b2h/Pmk+lZdJ0YWs+5tJ6J2+uOQPo7tlBnY44QqQSPYvS0qVT4wqDJgwrZaLe47etJLLQRFia71GYITw==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"@volar/typescript": "2.4.15", "@volar/typescript": "2.4.15",
"@vue/language-core": "2.2.12" "@vue/language-core": "2.2.12"

View File

@@ -83,6 +83,14 @@
></div> ></div>
</div> </div>
</div> </div>
<!-- Tier Indicator -->
<span
v-if="tierDisplay"
class="inline-flex items-center rounded bg-blue-100 px-1.5 py-0.5 text-xs font-medium text-blue-700 dark:bg-blue-900/30 dark:text-blue-400"
>
{{ tierDisplay }}
</span>
</div> </div>
</template> </template>
@@ -140,4 +148,23 @@ const statusText = computed(() => {
return props.account.status return props.account.status
}) })
// Computed: tier display
const tierDisplay = computed(() => {
const credentials = props.account.credentials as Record<string, any> | undefined
const tierId = credentials?.tier_id
if (!tierId || tierId === 'unknown') return null
const tierMap: Record<string, string> = {
'free': 'Free',
'payg': 'Pay-as-you-go',
'pay-as-you-go': 'Pay-as-you-go',
'enterprise': 'Enterprise',
'LEGACY': 'Legacy',
'PRO': 'Pro',
'ULTRA': 'Ultra'
}
return tierMap[tierId] || tierId
})
</script> </script>