diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index c461198b..139a3a39 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 9ccbddc2..694d05a7 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -43,6 +43,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { return nil, err } userRepository := repository.NewUserRepository(client, db) + redeemCodeRepository := repository.NewRedeemCodeRepository(client) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) redisClient := repository.ProvideRedis(configConfig) @@ -61,24 +62,23 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) + authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + redeemCache := repository.NewRedeemCache(redisClient) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) if err != nil { return nil, err } totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - redeemCodeRepository := repository.NewRedeemCodeRepository(client) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) - redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) announcementRepository := repository.NewAnnouncementRepository(client) diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 586e6309..35a6a5b7 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -37,6 +37,7 @@ const ( RedeemTypeBalance = "balance" RedeemTypeConcurrency = "concurrency" RedeemTypeSubscription = "subscription" + RedeemTypeInvitation = "invitation" ) // PromoCode status constants diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 9192fe45..d10d678b 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -47,6 +47,8 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` + // 从指定分组复制账号(创建后自动绑定) + CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } // UpdateGroupRequest represents update group request @@ -74,6 +76,8 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string `json:"supported_model_scopes"` + // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) + CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } // List handles listing all groups with pagination @@ -182,6 +186,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) @@ -227,6 +232,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index f1b68334..e229385f 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -29,7 +29,7 @@ func NewRedeemHandler(adminService service.AdminService) *RedeemHandler { // GenerateRedeemCodesRequest represents generate redeem codes request type GenerateRedeemCodesRequest struct { Count int `json:"count" binding:"required,min=1,max=100"` - Type string `json:"type" binding:"required,oneof=balance concurrency subscription"` + Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` Value float64 `json:"value" binding:"min=0"` GroupID *int64 `json:"group_id"` // 订阅类型必填 ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年 diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index cdad3659..1e723ee5 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -49,6 +49,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EmailVerifyEnabled: settings.EmailVerifyEnabled, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), SMTPHost: settings.SMTPHost, @@ -94,11 +95,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 // 邮件服务设置 SMTPHost string `json:"smtp_host"` @@ -291,6 +293,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EmailVerifyEnabled: req.EmailVerifyEnabled, PromoCodeEnabled: req.PromoCodeEnabled, PasswordResetEnabled: req.PasswordResetEnabled, + InvitationCodeEnabled: req.InvitationCodeEnabled, TotpEnabled: req.TotpEnabled, SMTPHost: req.SMTPHost, SMTPPort: req.SMTPPort, @@ -370,6 +373,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, PromoCodeEnabled: updatedSettings.PromoCodeEnabled, PasswordResetEnabled: updatedSettings.PasswordResetEnabled, + InvitationCodeEnabled: updatedSettings.InvitationCodeEnabled, TotpEnabled: updatedSettings.TotpEnabled, TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), SMTPHost: updatedSettings.SMTPHost, diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 3522407d..0b8d5e15 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -15,23 +15,25 @@ import ( // AuthHandler handles authentication-related requests type AuthHandler struct { - cfg *config.Config - authService *service.AuthService - userService *service.UserService - settingSvc *service.SettingService - promoService *service.PromoService - totpService *service.TotpService + cfg *config.Config + authService *service.AuthService + userService *service.UserService + settingSvc *service.SettingService + promoService *service.PromoService + redeemService *service.RedeemService + totpService *service.TotpService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler { return &AuthHandler{ - cfg: cfg, - authService: authService, - userService: userService, - settingSvc: settingService, - promoService: promoService, - totpService: totpService, + cfg: cfg, + authService: authService, + userService: userService, + settingSvc: settingService, + promoService: promoService, + redeemService: redeemService, + totpService: totpService, } } @@ -41,7 +43,8 @@ type RegisterRequest struct { Password string `json:"password" binding:"required,min=6"` VerifyCode string `json:"verify_code"` TurnstileToken string `json:"turnstile_token"` - PromoCode string `json:"promo_code"` // 注册优惠码 + PromoCode string `json:"promo_code"` // 注册优惠码 + InvitationCode string `json:"invitation_code"` // 邀请码 } // SendVerifyCodeRequest 发送验证码请求 @@ -87,7 +90,7 @@ func (h *AuthHandler) Register(c *gin.Context) { } } - token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode) + token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return @@ -346,6 +349,66 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) { }) } +// ValidateInvitationCodeRequest 验证邀请码请求 +type ValidateInvitationCodeRequest struct { + Code string `json:"code" binding:"required"` +} + +// ValidateInvitationCodeResponse 验证邀请码响应 +type ValidateInvitationCodeResponse struct { + Valid bool `json:"valid"` + ErrorCode string `json:"error_code,omitempty"` +} + +// ValidateInvitationCode 验证邀请码(公开接口,注册前调用) +// POST /api/v1/auth/validate-invitation-code +func (h *AuthHandler) ValidateInvitationCode(c *gin.Context) { + // 检查邀请码功能是否启用 + if h.settingSvc == nil || !h.settingSvc.IsInvitationCodeEnabled(c.Request.Context()) { + response.Success(c, ValidateInvitationCodeResponse{ + Valid: false, + ErrorCode: "INVITATION_CODE_DISABLED", + }) + return + } + + var req ValidateInvitationCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + // 验证邀请码 + redeemCode, err := h.redeemService.GetByCode(c.Request.Context(), req.Code) + if err != nil { + response.Success(c, ValidateInvitationCodeResponse{ + Valid: false, + ErrorCode: "INVITATION_CODE_NOT_FOUND", + }) + return + } + + // 检查类型和状态 + if redeemCode.Type != service.RedeemTypeInvitation { + response.Success(c, ValidateInvitationCodeResponse{ + Valid: false, + ErrorCode: "INVITATION_CODE_INVALID", + }) + return + } + + if redeemCode.Status != service.StatusUnused { + response.Success(c, ValidateInvitationCodeResponse{ + Valid: false, + ErrorCode: "INVITATION_CODE_USED", + }) + return + } + + response.Success(c, ValidateInvitationCodeResponse{ + Valid: true, + }) +} // ForgotPasswordRequest 忘记密码请求 type ForgotPasswordRequest struct { Email string `json:"email" binding:"required,email"` diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 101c44c9..c22d8100 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -381,6 +381,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { AccountID: l.AccountID, RequestID: l.RequestID, Model: l.Model, + ReasoningEffort: l.ReasoningEffort, GroupID: l.GroupID, SubscriptionID: l.SubscriptionID, InputTokens: l.InputTokens, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 152da756..be94bc16 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -6,6 +6,7 @@ type SystemSettings struct { EmailVerifyEnabled bool `json:"email_verify_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 @@ -63,6 +64,7 @@ type PublicSettings struct { EmailVerifyEnabled bool `json:"email_verify_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileSiteKey string `json:"turnstile_site_key"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index ab0b86fe..ecfbb7c2 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -237,6 +237,9 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` + // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API). + // nil means not provided / not applicable. + ReasoningEffort *string `json:"reasoning_effort,omitempty"` GroupID *int64 `json:"group_id"` SubscriptionID *int64 `json:"subscription_id"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 86564db3..bef2e5e9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -596,7 +596,6 @@ func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service cloned.Group = group return &cloned } - // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { @@ -849,6 +848,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + // 检查是否为 Claude Code 客户端,设置到 context 中 + SetClaudeCodeClientContext(c, body) + setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 53431dc3..cfb59c04 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -371,18 +371,21 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 6) record usage async + // 6) record usage async (Gemini 使用长上下文双倍计费) go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + + if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + LongContextThreshold: 200000, // Gemini 200K 阈值 + LongContextMultiplier: 2.0, // 超出部分双倍计费 }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 9fd27dc3..2029f116 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -36,6 +36,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { EmailVerifyEnabled: settings.EmailVerifyEnabled, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, TurnstileEnabled: settings.TurnstileEnabled, TurnstileSiteKey: settings.TurnstileSiteKey, diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index d1712c98..59ded99e 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -81,7 +81,6 @@ func ForwardBaseURLs() []string { } return reordered } - // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) type URLAvailability struct { mu sync.RWMutex diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go index d1a56a84..8b3441dc 100644 --- a/backend/internal/pkg/claude/constants.go +++ b/backend/internal/pkg/claude/constants.go @@ -9,11 +9,26 @@ const ( BetaClaudeCode = "claude-code-20250219" BetaInterleavedThinking = "interleaved-thinking-2025-05-14" BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14" + BetaTokenCounting = "token-counting-2024-11-01" ) // DefaultBetaHeader Claude Code 客户端默认的 anthropic-beta header const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming +// MessageBetaHeaderNoTools /v1/messages 在无工具时的 beta header +// +// NOTE: Claude Code OAuth credentials are scoped to Claude Code. When we "mimic" +// Claude Code for non-Claude-Code clients, we must include the claude-code beta +// even if the request doesn't use tools, otherwise upstream may reject the +// request as a non-Claude-Code API request. +const MessageBetaHeaderNoTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + +// MessageBetaHeaderWithTools /v1/messages 在有工具时的 beta header +const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + +// CountTokensBetaHeader count_tokens 请求使用的 anthropic-beta header +const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting + // HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta) const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking @@ -25,15 +40,17 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking // DefaultHeaders 是 Claude Code 客户端默认请求头。 var DefaultHeaders = map[string]string{ - "User-Agent": "claude-cli/2.0.62 (external, cli)", + // Keep these in sync with recent Claude CLI traffic to reduce the chance + // that Claude Code-scoped OAuth credentials are rejected as "non-CLI" usage. + "User-Agent": "claude-cli/2.1.22 (external, cli)", "X-Stainless-Lang": "js", - "X-Stainless-Package-Version": "0.52.0", + "X-Stainless-Package-Version": "0.70.0", "X-Stainless-OS": "Linux", - "X-Stainless-Arch": "x64", + "X-Stainless-Arch": "arm64", "X-Stainless-Runtime": "node", - "X-Stainless-Runtime-Version": "v22.14.0", + "X-Stainless-Runtime-Version": "v24.13.0", "X-Stainless-Retry-Count": "0", - "X-Stainless-Timeout": "60", + "X-Stainless-Timeout": "600", "X-App": "cli", "Anthropic-Dangerous-Direct-Browser-Access": "true", } @@ -79,3 +96,39 @@ func DefaultModelIDs() []string { // DefaultTestModel 测试时使用的默认模型 const DefaultTestModel = "claude-sonnet-4-5-20250929" + +// ModelIDOverrides Claude OAuth 请求需要的模型 ID 映射 +var ModelIDOverrides = map[string]string{ + "claude-sonnet-4-5": "claude-sonnet-4-5-20250929", + "claude-opus-4-5": "claude-opus-4-5-20251101", + "claude-haiku-4-5": "claude-haiku-4-5-20251001", +} + +// ModelIDReverseOverrides 用于将上游模型 ID 还原为短名 +var ModelIDReverseOverrides = map[string]string{ + "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", + "claude-opus-4-5-20251101": "claude-opus-4-5", + "claude-haiku-4-5-20251001": "claude-haiku-4-5", +} + +// NormalizeModelID 根据 Claude OAuth 规则映射模型 +func NormalizeModelID(id string) string { + if id == "" { + return id + } + if mapped, ok := ModelIDOverrides[id]; ok { + return mapped + } + return id +} + +// DenormalizeModelID 将上游模型 ID 转换为短名 +func DenormalizeModelID(id string) string { + if id == "" { + return id + } + if mapped, ok := ModelIDReverseOverrides[id]; ok { + return mapped + } + return id +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 53624635..d8cec491 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -439,3 +439,61 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 return counts, nil } + +// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) +func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + if len(groupIDs) == 0 { + return nil, nil + } + + rows, err := r.sql.QueryContext( + ctx, + "SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id", + pq.Array(groupIDs), + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + var accountIDs []int64 + for rows.Next() { + var accountID int64 + if err := rows.Scan(&accountID); err != nil { + return nil, err + } + accountIDs = append(accountIDs, accountID) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return accountIDs, nil +} + +// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定) +func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + if len(accountIDs) == 0 { + return nil + } + + // 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定 + _, err := r.sql.ExecContext( + ctx, + `INSERT INTO account_groups (account_id, group_id, priority, created_at) + SELECT unnest($1::bigint[]), $2, 50, NOW() + ON CONFLICT (account_id, group_id) DO NOTHING`, + pq.Array(accountIDs), + groupID, + ) + if err != nil { + return err + } + + // 发送调度器事件 + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { + log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) + } + + return nil +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 963db7ba..dc8f1460 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at" type usageLogRepository struct { client *dbent.Client @@ -111,21 +111,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration_ms, first_token_ms, user_agent, - ip_address, - image_count, - image_size, - created_at - ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30 - ) - ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING id, created_at - ` + ip_address, + image_count, + image_size, + reasoning_effort, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id, created_at + ` groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) @@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + reasoningEffort := nullString(log.ReasoningEffort) var requestIDArg any if requestID != "" { @@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress, log.ImageCount, imageSize, + reasoningEffort, createdAt, } if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { @@ -2090,6 +2093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString + reasoningEffort sql.NullString createdAt time.Time ) @@ -2124,6 +2128,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, + &reasoningEffort, &createdAt, ); err != nil { return nil, err @@ -2183,6 +2188,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if reasoningEffort.Valid { + log.ReasoningEffort = &reasoningEffort.String + } return log, nil } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 22e6213e..d6b2f1b8 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -488,6 +488,7 @@ func TestAPIContracts(t *testing.T) { "fallback_model_openai": "gpt-4o", "enable_identity_patch": true, "identity_patch_prompt": "", + "invitation_code_enabled": false, "home_content": "", "hide_ccs_import_button": false, "purchase_subscription_enabled": false, @@ -599,8 +600,8 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) - authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil) +adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) +authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) @@ -880,6 +881,14 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i return 0, errors.New("not implemented") } +func (stubGroupRepo) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return errors.New("not implemented") +} + +func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, errors.New("not implemented") +} + type stubAccountRepo struct { bulkUpdateIDs []int64 } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 33a88e82..24f6d549 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -32,6 +32,10 @@ func RegisterAuthRoutes( auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.ValidatePromoCode) + // 邀请码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) + auth.POST("/validate-invitation-code", rateLimiter.LimitWithOptions("validate-invitation", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.ValidateInvitationCode) // 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close) auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 182e0161..7b958838 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -410,6 +410,22 @@ func (a *Account) GetExtraString(key string) string { return "" } +func (a *Account) GetClaudeUserID() string { + if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" { + return v + } + if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" { + return v + } + return "" +} + func (a *Account) IsCustomErrorCodesEnabled() bool { if a.Type != AccountTypeAPIKey || a.Credentials == nil { return false diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 46376c69..3290fe52 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -123,7 +123,7 @@ func createTestPayload(modelID string) (map[string]any, error) { "system": []map[string]any{ { "type": "text", - "text": "You are Claude Code, Anthropic's official CLI for Claude.", + "text": claudeCodeSystemPrompt, "cache_control": map[string]string{ "type": "ephemeral", }, diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 1449070e..52a10476 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -115,6 +115,8 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string + // 从指定分组复制账号(创建分组后在同一事务内绑定) + CopyAccountsFromGroupIDs []int64 } type UpdateGroupInput struct { @@ -142,6 +144,8 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string + // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) + CopyAccountsFromGroupIDs []int64 } type CreateAccountInput struct { @@ -598,6 +602,38 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn mcpXMLInject = *input.MCPXMLInject } + // 如果指定了复制账号的源分组,先获取账号 ID 列表 + var accountIDsToCopy []int64 + if len(input.CopyAccountsFromGroupIDs) > 0 { + // 去重源分组 IDs + seen := make(map[int64]struct{}) + uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) + for _, srcGroupID := range input.CopyAccountsFromGroupIDs { + if _, exists := seen[srcGroupID]; !exists { + seen[srcGroupID] = struct{}{} + uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) + } + } + + // 校验源分组的平台是否与新分组一致 + for _, srcGroupID := range uniqueSourceGroupIDs { + srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) + if err != nil { + return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) + } + if srcGroup.Platform != platform { + return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform) + } + } + + // 获取所有源分组的账号(去重) + var err error + accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) + } + } + group := &Group{ Name: input.Name, Description: input.Description, @@ -622,6 +658,15 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err } + + // 如果有需要复制的账号,绑定到新分组 + if len(accountIDsToCopy) > 0 { + if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { + return nil, fmt.Errorf("failed to bind accounts to new group: %w", err) + } + group.AccountCount = int64(len(accountIDsToCopy)) + } + return group, nil } @@ -810,6 +855,54 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + + // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) + if len(input.CopyAccountsFromGroupIDs) > 0 { + // 去重源分组 IDs + seen := make(map[int64]struct{}) + uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) + for _, srcGroupID := range input.CopyAccountsFromGroupIDs { + // 校验:源分组不能是自身 + if srcGroupID == id { + return nil, fmt.Errorf("cannot copy accounts from self") + } + // 去重 + if _, exists := seen[srcGroupID]; !exists { + seen[srcGroupID] = struct{}{} + uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) + } + } + + // 校验源分组的平台是否与当前分组一致 + for _, srcGroupID := range uniqueSourceGroupIDs { + srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) + if err != nil { + return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) + } + if srcGroup.Platform != group.Platform { + return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform) + } + } + + // 获取所有源分组的账号(去重) + accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) + } + + // 先清空当前分组的所有账号绑定 + if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil { + return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) + } + + // 再绑定源分组的账号 + if len(accountIDsToCopy) > 0 { + if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { + return nil, fmt.Errorf("failed to bind accounts to group: %w", err) + } + } + } + if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 6472ccbb..923d33ab 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -164,6 +164,14 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI panic("unexpected DeleteAccountGroupsByGroupID call") } +func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + type proxyRepoStub struct { deleteErr error countErr error diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 1454dccd..9b8c0107 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -108,6 +108,14 @@ func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, panic("unexpected DeleteAccountGroupsByGroupID call") } +func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { repo := &groupRepoStubForAdmin{} @@ -379,6 +387,14 @@ func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.C panic("unexpected DeleteAccountGroupsByGroupID call") } +func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + type groupRepoStubForInvalidRequestFallback struct { groups map[int64]*Group created *Group @@ -748,4 +764,4 @@ func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *tes require.NotNil(t, group) require.NotNil(t, repo.updated) require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) -} + } diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2af9efdb..a3db1b09 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -302,13 +302,11 @@ func logPrefix(sessionID, accountName string) string { } // Antigravity 直接支持的模型(精确匹配透传) +// 注意:gemini-2.5 系列已移除,统一映射到 gemini-3 系列 var antigravitySupportedModels = map[string]bool{ "claude-opus-4-5-thinking": true, "claude-sonnet-4-5": true, "claude-sonnet-4-5-thinking": true, - "gemini-2.5-flash": true, - "gemini-2.5-flash-lite": true, - "gemini-2.5-flash-thinking": true, "gemini-3-flash": true, "gemini-3-pro-low": true, "gemini-3-pro-high": true, @@ -317,23 +315,32 @@ var antigravitySupportedModels = map[string]bool{ // Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先) // 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀) +// gemini-2.5 系列统一映射到 gemini-3 系列(Antigravity 上游不再支持 2.5) var antigravityPrefixMapping = []struct { prefix string target string }{ - // 长前缀优先 - {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image - {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 - {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash - {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx - {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx - {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet + // gemini-2.5 → gemini-3 映射(长前缀优先) + {"gemini-2.5-flash-thinking", "gemini-3-flash"}, // gemini-2.5-flash-thinking → gemini-3-flash + {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → gemini-3-pro-image + {"gemini-2.5-flash-lite", "gemini-3-flash"}, // gemini-2.5-flash-lite → gemini-3-flash + {"gemini-2.5-flash", "gemini-3-flash"}, // gemini-2.5-flash → gemini-3-flash + {"gemini-2.5-pro-preview", "gemini-3-pro-high"}, // gemini-2.5-pro-preview → gemini-3-pro-high + {"gemini-2.5-pro-exp", "gemini-3-pro-high"}, // gemini-2.5-pro-exp → gemini-3-pro-high + {"gemini-2.5-pro", "gemini-3-pro-high"}, // gemini-2.5-pro → gemini-3-pro-high + // gemini-3 前缀映射 + {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 + {"gemini-3-flash", "gemini-3-flash"}, // gemini-3-flash-preview 等 → gemini-3-flash + {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等 + // Claude 映射 + {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx + {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx + {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"claude-opus-4-5", "claude-opus-4-5-thinking"}, {"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet {"claude-sonnet-4", "claude-sonnet-4-5"}, {"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet {"claude-opus-4", "claude-opus-4-5-thinking"}, - {"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等 } // AntigravityGatewayService 处理 Antigravity 平台的 API 转发 diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index ffdcdc73..32a591ef 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -103,6 +103,10 @@ func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http. return s.resp, s.err } +func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + return s.resp, s.err +} + func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 179a3520..e269103a 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -134,18 +134,18 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "claude-sonnet-4-5", }, - // 3. Gemini 透传 + // 3. Gemini 2.5 → 3 映射 { - name: "Gemini透传 - gemini-2.5-flash", + name: "Gemini映射 - gemini-2.5-flash → gemini-3-flash", requestedModel: "gemini-2.5-flash", accountMapping: nil, - expected: "gemini-2.5-flash", + expected: "gemini-3-flash", }, { - name: "Gemini透传 - gemini-2.5-pro", + name: "Gemini映射 - gemini-2.5-pro → gemini-3-pro-high", requestedModel: "gemini-2.5-pro", accountMapping: nil, - expected: "gemini-2.5-pro", + expected: "gemini-3-pro-high", }, { name: "Gemini透传 - gemini-future-model", diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index f51fae24..25604d2c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -19,17 +19,19 @@ import ( ) var ( - ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") - ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") - ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") - ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") - ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") - ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") - ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") - ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") - ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") - ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") + ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") + ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") + ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") + ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") + ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") + ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") + ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") + ErrInvitationCodeRequired = infraerrors.BadRequest("INVITATION_CODE_REQUIRED", "invitation code is required") + ErrInvitationCodeInvalid = infraerrors.BadRequest("INVITATION_CODE_INVALID", "invalid or used invitation code") ) // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 @@ -47,6 +49,7 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { userRepo UserRepository + redeemRepo RedeemCodeRepository cfg *config.Config settingService *SettingService emailService *EmailService @@ -58,6 +61,7 @@ type AuthService struct { // NewAuthService 创建认证服务实例 func NewAuthService( userRepo UserRepository, + redeemRepo RedeemCodeRepository, cfg *config.Config, settingService *SettingService, emailService *EmailService, @@ -67,6 +71,7 @@ func NewAuthService( ) *AuthService { return &AuthService{ userRepo: userRepo, + redeemRepo: redeemRepo, cfg: cfg, settingService: settingService, emailService: emailService, @@ -78,11 +83,11 @@ func NewAuthService( // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) { - return s.RegisterWithVerification(ctx, email, password, "", "") + return s.RegisterWithVerification(ctx, email, password, "", "", "") } -// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户 -func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) { +// RegisterWithVerification 用户注册(支持邮件验证、优惠码和邀请码),返回token和用户 +func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode, invitationCode string) (string, *User, error) { // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled @@ -93,6 +98,26 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrEmailReserved } + // 检查是否需要邀请码 + var invitationRedeemCode *RedeemCode + if s.settingService != nil && s.settingService.IsInvitationCodeEnabled(ctx) { + if invitationCode == "" { + return "", nil, ErrInvitationCodeRequired + } + // 验证邀请码 + redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) + if err != nil { + log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) + return "", nil, ErrInvitationCodeInvalid + } + // 检查类型和状态 + if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) + return "", nil, ErrInvitationCodeInvalid + } + invitationRedeemCode = redeemCode + } + // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { // 如果邮件验证已开启但邮件服务未配置,拒绝注册 @@ -153,6 +178,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrServiceUnavailable } + // 标记邀请码为已使用(如果使用了邀请码) + if invitationRedeemCode != nil { + if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { + // 邀请码标记失败不影响注册,只记录日志 + log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) + } + } // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index e31ca561..aa3c769e 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -115,6 +115,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E return NewAuthService( repo, + nil, // redeemRepo cfg, settingService, emailService, @@ -152,7 +153,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi }, nil) // 应返回服务不可用错误,而不是允许绕过验证 - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "") require.ErrorIs(t, err, ErrServiceUnavailable) } @@ -164,7 +165,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { SettingKeyEmailVerifyEnabled: "true", }, cache) - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -178,7 +179,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { SettingKeyEmailVerifyEnabled: "true", }, cache) - _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "") + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "") require.ErrorIs(t, err, ErrInvalidVerifyCode) require.ErrorContains(t, err, "verify code") } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index f2afc343..db5a9708 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -241,6 +241,76 @@ func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageToken return s.CalculateCost(model, tokens, multiplier) } +// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费 +// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费 +// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍) +// +// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0 +// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k) +// 范围内正常计费,范围外 × 2 计费 +func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) { + // 未启用长上下文计费,直接走正常计费 + if threshold <= 0 || extraMultiplier <= 1 { + return s.CalculateCost(model, tokens, rateMultiplier) + } + + // 计算总输入 token(缓存读取 + 新输入) + total := tokens.CacheReadTokens + tokens.InputTokens + if total <= threshold { + return s.CalculateCost(model, tokens, rateMultiplier) + } + + // 拆分成范围内和范围外 + var inRangeCacheTokens, inRangeInputTokens int + var outRangeCacheTokens, outRangeInputTokens int + + if tokens.CacheReadTokens >= threshold { + // 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入 + inRangeCacheTokens = threshold + inRangeInputTokens = 0 + outRangeCacheTokens = tokens.CacheReadTokens - threshold + outRangeInputTokens = tokens.InputTokens + } else { + // 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入 + inRangeCacheTokens = tokens.CacheReadTokens + inRangeInputTokens = threshold - tokens.CacheReadTokens + outRangeCacheTokens = 0 + outRangeInputTokens = tokens.InputTokens - inRangeInputTokens + } + + // 范围内部分:正常计费 + inRangeTokens := UsageTokens{ + InputTokens: inRangeInputTokens, + OutputTokens: tokens.OutputTokens, // 输出只算一次 + CacheCreationTokens: tokens.CacheCreationTokens, + CacheReadTokens: inRangeCacheTokens, + } + inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier) + if err != nil { + return nil, err + } + + // 范围外部分:× extraMultiplier 计费 + outRangeTokens := UsageTokens{ + InputTokens: outRangeInputTokens, + CacheReadTokens: outRangeCacheTokens, + } + outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) + if err != nil { + return inRangeCost, nil // 出错时返回范围内成本 + } + + // 合并成本 + return &CostBreakdown{ + InputCost: inRangeCost.InputCost + outRangeCost.InputCost, + OutputCost: inRangeCost.OutputCost, + CacheCreationCost: inRangeCost.CacheCreationCost, + CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost, + TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost, + ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost, + }, nil +} + // ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配) func (s *BillingService) ListSupportedModels() []string { models := make([]string, 0) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 218b7aae..0295c23b 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -39,6 +39,7 @@ const ( RedeemTypeBalance = domain.RedeemTypeBalance RedeemTypeConcurrency = domain.RedeemTypeConcurrency RedeemTypeSubscription = domain.RedeemTypeSubscription + RedeemTypeInvitation = domain.RedeemTypeInvitation ) // PromoCode status constants @@ -72,10 +73,11 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // Setting keys const ( // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 - SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 - SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 + SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) + SettingKeyInvitationCodeEnabled = "invitation_code_enabled" // 是否启用邀请码注册 // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go new file mode 100644 index 00000000..dd58c183 --- /dev/null +++ b/backend/internal/service/gateway_beta_test.go @@ -0,0 +1,23 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMergeAnthropicBeta(t *testing.T) { + got := mergeAnthropicBeta( + []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}, + "foo, oauth-2025-04-20,bar, foo", + ) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14,foo,bar", got) +} + +func TestMergeAnthropicBeta_EmptyIncoming(t *testing.T) { + got := mergeAnthropicBeta( + []string{"oauth-2025-04-20", "interleaved-thinking-2025-05-14"}, + "", + ) + require.Equal(t, "oauth-2025-04-20,interleaved-thinking-2025-05-14", got) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 26eb24e4..4bfa23d1 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -266,6 +266,14 @@ func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Conte return 0, nil } +func (m *mockGroupRepoForGateway) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return nil +} + +func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, nil +} + func ptr[T any](v T) *T { return &v } diff --git a/backend/internal/service/gateway_oauth_metadata_test.go b/backend/internal/service/gateway_oauth_metadata_test.go new file mode 100644 index 00000000..ed6f1887 --- /dev/null +++ b/backend/internal/service/gateway_oauth_metadata_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuildOAuthMetadataUserID_FallbackWithoutAccountUUID(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Model: "claude-sonnet-4-5", + Stream: true, + MetadataUserID: "", + System: nil, + Messages: nil, + } + + account := &Account{ + ID: 123, + Type: AccountTypeOAuth, + Extra: map[string]any{}, // intentionally missing account_uuid / claude_user_id + } + + fp := &Fingerprint{ClientID: "deadbeef"} // should be used as user id in legacy format + + got := svc.buildOAuthMetadataUserID(parsed, account, fp) + require.NotEmpty(t, got) + + // Legacy format: user_{client}_account__session_{uuid} + re := regexp.MustCompile(`^user_[a-zA-Z0-9]+_account__session_[a-f0-9-]{36}$`) + require.True(t, re.MatchString(got), "unexpected user_id format: %s", got) +} + +func TestBuildOAuthMetadataUserID_UsesAccountUUIDWhenPresent(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Model: "claude-sonnet-4-5", + Stream: true, + MetadataUserID: "", + } + + account := &Account{ + ID: 123, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "account_uuid": "acc-uuid", + "claude_user_id": "clientid123", + "anthropic_user_id": "", + }, + } + + got := svc.buildOAuthMetadataUserID(parsed, account, nil) + require.NotEmpty(t, got) + + // New format: user_{client}_account_{account_uuid}_session_{uuid} + re := regexp.MustCompile(`^user_clientid123_account_acc-uuid_session_[a-f0-9-]{36}$`) + require.True(t, re.MatchString(got), "unexpected user_id format: %s", got) +} diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index b056f8fa..52c75d1d 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -2,6 +2,7 @@ package service import ( "encoding/json" + "strings" "testing" "github.com/stretchr/testify/require" @@ -134,6 +135,8 @@ func TestSystemIncludesClaudeCodePrompt(t *testing.T) { } func TestInjectClaudeCodePrompt(t *testing.T) { + claudePrefix := strings.TrimSpace(claudeCodeSystemPrompt) + tests := []struct { name string body string @@ -162,7 +165,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { system: "Custom prompt", wantSystemLen: 2, wantFirstText: claudeCodeSystemPrompt, - wantSecondText: "Custom prompt", + wantSecondText: claudePrefix + "\n\nCustom prompt", }, { name: "string system equals Claude Code prompt", @@ -178,7 +181,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { // Claude Code + Custom = 2 wantSystemLen: 2, wantFirstText: claudeCodeSystemPrompt, - wantSecondText: "Custom", + wantSecondText: claudePrefix + "\n\nCustom", }, { name: "array system with existing Claude Code prompt (should dedupe)", @@ -190,7 +193,7 @@ func TestInjectClaudeCodePrompt(t *testing.T) { // Claude Code at start + Other = 2 (deduped) wantSystemLen: 2, wantFirstText: claudeCodeSystemPrompt, - wantSecondText: "Other", + wantSecondText: claudePrefix + "\n\nOther", }, { name: "empty array", diff --git a/backend/internal/service/gateway_sanitize_test.go b/backend/internal/service/gateway_sanitize_test.go new file mode 100644 index 00000000..8fa971ca --- /dev/null +++ b/backend/internal/service/gateway_sanitize_test.go @@ -0,0 +1,21 @@ +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSanitizeOpenCodeText_RewritesCanonicalSentence(t *testing.T) { + in := "You are OpenCode, the best coding agent on the planet." + got := sanitizeSystemText(in) + require.Equal(t, strings.TrimSpace(claudeCodeSystemPrompt), got) +} + +func TestSanitizeToolDescription_DoesNotRewriteKeywords(t *testing.T) { + in := "OpenCode and opencode are mentioned." + got := sanitizeToolDescription(in) + // We no longer rewrite tool descriptions; only redact obvious path leaks. + require.Equal(t, in, got) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 30078e3c..065f3cba 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -20,12 +20,14 @@ import ( "strings" "sync/atomic" "time" + "unicode" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -37,8 +39,15 @@ const ( claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 40 * 1024 * 1024 - claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." - maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) + // to match real Claude CLI traffic as closely as possible. When we need a visual + // separator between system blocks, we add "\n\n" at concatenation time. + claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." + maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 +) + +const ( + claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) func (s *GatewayService) debugModelRoutingEnabled() bool { @@ -46,6 +55,11 @@ func (s *GatewayService) debugModelRoutingEnabled() bool { return v == "1" || v == "true" || v == "yes" || v == "on" } +func (s *GatewayService) debugClaudeMimicEnabled() bool { + v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return v == "1" || v == "true" || v == "yes" || v == "on" +} + func shortSessionHash(sessionHash string) string { if sessionHash == "" { return "" @@ -65,12 +79,178 @@ func normalizeClaudeModelForAnthropic(requestedModel string) string { return requestedModel } +func redactAuthHeaderValue(v string) string { + v = strings.TrimSpace(v) + if v == "" { + return "" + } + // Keep scheme for debugging, redact secret. + if strings.HasPrefix(strings.ToLower(v), "bearer ") { + return "Bearer [redacted]" + } + return "[redacted]" +} + +func safeHeaderValueForLog(key string, v string) string { + key = strings.ToLower(strings.TrimSpace(key)) + switch key { + case "authorization", "x-api-key": + return redactAuthHeaderValue(v) + default: + return strings.TrimSpace(v) + } +} + +func extractSystemPreviewFromBody(body []byte) string { + if len(body) == 0 { + return "" + } + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return "" + } + + switch { + case sys.IsArray(): + for _, item := range sys.Array() { + if !item.IsObject() { + continue + } + if strings.EqualFold(item.Get("type").String(), "text") { + if t := item.Get("text").String(); strings.TrimSpace(t) != "" { + return t + } + } + } + return "" + case sys.Type == gjson.String: + return sys.String() + default: + return "" + } +} + +func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string { + if req == nil { + return "" + } + + // Only log a minimal fingerprint to avoid leaking user content. + interesting := []string{ + "user-agent", + "x-app", + "anthropic-dangerous-direct-browser-access", + "anthropic-version", + "anthropic-beta", + "x-stainless-lang", + "x-stainless-package-version", + "x-stainless-os", + "x-stainless-arch", + "x-stainless-runtime", + "x-stainless-runtime-version", + "x-stainless-retry-count", + "x-stainless-timeout", + "authorization", + "x-api-key", + "content-type", + "accept", + "x-stainless-helper-method", + } + + h := make([]string, 0, len(interesting)) + for _, k := range interesting { + if v := req.Header.Get(k); v != "" { + h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v))) + } + } + + metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()) + sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body)) + + // Truncate preview to keep logs sane. + if len(sysPreview) > 300 { + sysPreview = sysPreview[:300] + "..." + } + sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n") + sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r") + + aid := int64(0) + aname := "" + if account != nil { + aid = account.ID + aname = account.Name + } + + return fmt.Sprintf( + "url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}", + req.URL.String(), + aid, + aname, + tokenType, + mimicClaudeCode, + metaUserID, + sysPreview, + strings.Join(h, " "), + ) +} + +func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) { + line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode) + if line == "" { + return + } + log.Printf("[ClaudeMimicDebug] %s", line) +} + +func isClaudeCodeCredentialScopeError(msg string) bool { + m := strings.ToLower(strings.TrimSpace(msg)) + if m == "" { + return false + } + return strings.Contains(m, "only authorized for use with claude code") && + strings.Contains(m, "cannot be used for other api requests") +} + // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`) + toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`) + toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`) + toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`) + modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`) + toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`) + toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`) + + claudeToolNameOverrides = map[string]string{ + "bash": "Bash", + "read": "Read", + "edit": "Edit", + "write": "Write", + "task": "Task", + "glob": "Glob", + "grep": "Grep", + "webfetch": "WebFetch", + "websearch": "WebSearch", + "todowrite": "TodoWrite", + "question": "AskUserQuestion", + } + openCodeToolOverrides = map[string]string{ + "Bash": "bash", + "Read": "read", + "Edit": "edit", + "Write": "write", + "Task": "task", + "Glob": "glob", + "Grep": "grep", + "WebFetch": "webfetch", + "WebSearch": "websearch", + "TodoWrite": "todowrite", + "AskUserQuestion": "question", + } // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 @@ -436,6 +616,394 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte return newBody } +type claudeOAuthNormalizeOptions struct { + injectMetadata bool + metadataUserID string + stripSystemCacheControl bool +} + +func stripToolPrefix(value string) string { + if value == "" { + return value + } + return toolPrefixRe.ReplaceAllString(value, "") +} + +func toPascalCase(value string) string { + if value == "" { + return value + } + normalized := toolNameBoundaryRe.ReplaceAllString(value, " ") + tokens := make([]string, 0) + for _, token := range strings.Fields(normalized) { + expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2") + parts := strings.Fields(expanded) + if len(parts) > 0 { + tokens = append(tokens, parts...) + } + } + if len(tokens) == 0 { + return value + } + var builder strings.Builder + for _, token := range tokens { + lower := strings.ToLower(token) + if lower == "" { + continue + } + runes := []rune(lower) + runes[0] = unicode.ToUpper(runes[0]) + _, _ = builder.WriteString(string(runes)) + } + return builder.String() +} + +func toSnakeCase(value string) string { + if value == "" { + return value + } + output := toolNameCamelRe.ReplaceAllString(value, "$1_$2") + output = toolNameBoundaryRe.ReplaceAllString(output, "_") + output = strings.Trim(output, "_") + return strings.ToLower(output) +} + +func normalizeToolNameForClaude(name string, cache map[string]string) string { + if name == "" { + return name + } + stripped := stripToolPrefix(name) + mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] + if !ok { + mapped = toPascalCase(stripped) + } + if mapped != "" && cache != nil && mapped != stripped { + cache[mapped] = stripped + } + if mapped == "" { + return stripped + } + return mapped +} + +func normalizeToolNameForOpenCode(name string, cache map[string]string) string { + if name == "" { + return name + } + stripped := stripToolPrefix(name) + if cache != nil { + if mapped, ok := cache[stripped]; ok { + return mapped + } + } + if mapped, ok := openCodeToolOverrides[stripped]; ok { + return mapped + } + return toSnakeCase(stripped) +} + +func normalizeParamNameForOpenCode(name string, cache map[string]string) string { + if name == "" { + return name + } + if cache != nil { + if mapped, ok := cache[name]; ok { + return mapped + } + } + return name +} + +// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). +// We intentionally avoid broad keyword replacement in system prompts to prevent +// accidentally changing user-provided instructions. +func sanitizeSystemText(text string) string { + if text == "" { + return text + } + // Some clients include a fixed OpenCode identity sentence. Anthropic may treat + // this as a non-Claude-Code fingerprint, so rewrite it to the canonical + // Claude Code banner before generic "OpenCode"/"opencode" replacements. + text = strings.ReplaceAll( + text, + "You are OpenCode, the best coding agent on the planet.", + strings.TrimSpace(claudeCodeSystemPrompt), + ) + return text +} + +func sanitizeToolDescription(description string) string { + if description == "" { + return description + } + description = toolDescAbsPathRe.ReplaceAllString(description, "[path]") + description = toolDescWinPathRe.ReplaceAllString(description, "[path]") + // Intentionally do NOT rewrite tool descriptions (OpenCode/Claude strings). + // Tool names/skill names may rely on exact wording, and rewriting can be misleading. + return description +} + +func normalizeToolInputSchema(inputSchema any, cache map[string]string) { + schema, ok := inputSchema.(map[string]any) + if !ok { + return + } + properties, ok := schema["properties"].(map[string]any) + if !ok { + return + } + + newProperties := make(map[string]any, len(properties)) + for key, value := range properties { + snakeKey := toSnakeCase(key) + newProperties[snakeKey] = value + if snakeKey != key && cache != nil { + cache[snakeKey] = key + } + } + schema["properties"] = newProperties + + if required, ok := schema["required"].([]any); ok { + newRequired := make([]any, 0, len(required)) + for _, item := range required { + name, ok := item.(string) + if !ok { + newRequired = append(newRequired, item) + continue + } + snakeName := toSnakeCase(name) + newRequired = append(newRequired, snakeName) + if snakeName != name && cache != nil { + cache[snakeName] = name + } + } + schema["required"] = newRequired + } +} + +func stripCacheControlFromSystemBlocks(system any) bool { + blocks, ok := system.([]any) + if !ok { + return false + } + changed := false + for _, item := range blocks { + block, ok := item.(map[string]any) + if !ok { + continue + } + if _, exists := block["cache_control"]; !exists { + continue + } + delete(block, "cache_control") + changed = true + } + return changed +} + +func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) { + if len(body) == 0 { + return body, modelID, nil + } + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return body, modelID, nil + } + + toolNameMap := make(map[string]string) + + if system, ok := req["system"]; ok { + switch v := system.(type) { + case string: + sanitized := sanitizeSystemText(v) + if sanitized != v { + req["system"] = sanitized + } + case []any: + for _, item := range v { + block, ok := item.(map[string]any) + if !ok { + continue + } + if blockType, _ := block["type"].(string); blockType != "text" { + continue + } + text, ok := block["text"].(string) + if !ok || text == "" { + continue + } + sanitized := sanitizeSystemText(text) + if sanitized != text { + block["text"] = sanitized + } + } + } + } + + if rawModel, ok := req["model"].(string); ok { + normalized := claude.NormalizeModelID(rawModel) + if normalized != rawModel { + req["model"] = normalized + modelID = normalized + } + } + + if rawTools, exists := req["tools"]; exists { + switch tools := rawTools.(type) { + case []any: + for idx, tool := range tools { + toolMap, ok := tool.(map[string]any) + if !ok { + continue + } + if name, ok := toolMap["name"].(string); ok { + normalized := normalizeToolNameForClaude(name, toolNameMap) + if normalized != "" && normalized != name { + toolMap["name"] = normalized + } + } + if desc, ok := toolMap["description"].(string); ok { + sanitized := sanitizeToolDescription(desc) + if sanitized != desc { + toolMap["description"] = sanitized + } + } + if schema, ok := toolMap["input_schema"]; ok { + normalizeToolInputSchema(schema, toolNameMap) + } + tools[idx] = toolMap + } + req["tools"] = tools + case map[string]any: + normalizedTools := make(map[string]any, len(tools)) + for name, value := range tools { + normalized := normalizeToolNameForClaude(name, toolNameMap) + if normalized == "" { + normalized = name + } + if toolMap, ok := value.(map[string]any); ok { + toolMap["name"] = normalized + if desc, ok := toolMap["description"].(string); ok { + sanitized := sanitizeToolDescription(desc) + if sanitized != desc { + toolMap["description"] = sanitized + } + } + if schema, ok := toolMap["input_schema"]; ok { + normalizeToolInputSchema(schema, toolNameMap) + } + normalizedTools[normalized] = toolMap + continue + } + normalizedTools[normalized] = value + } + req["tools"] = normalizedTools + } + } else { + req["tools"] = []any{} + } + + if messages, ok := req["messages"].([]any); ok { + for _, msg := range messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + content, ok := msgMap["content"].([]any) + if !ok { + continue + } + for _, block := range content { + blockMap, ok := block.(map[string]any) + if !ok { + continue + } + if blockType, _ := blockMap["type"].(string); blockType != "tool_use" { + continue + } + if name, ok := blockMap["name"].(string); ok { + normalized := normalizeToolNameForClaude(name, toolNameMap) + if normalized != "" && normalized != name { + blockMap["name"] = normalized + } + } + } + } + } + + if opts.stripSystemCacheControl { + if system, ok := req["system"]; ok { + _ = stripCacheControlFromSystemBlocks(system) + } + } + + if opts.injectMetadata && opts.metadataUserID != "" { + metadata, ok := req["metadata"].(map[string]any) + if !ok { + metadata = map[string]any{} + req["metadata"] = metadata + } + if existing, ok := metadata["user_id"].(string); !ok || existing == "" { + metadata["user_id"] = opts.metadataUserID + } + } + + delete(req, "temperature") + delete(req, "tool_choice") + + newBody, err := json.Marshal(req) + if err != nil { + return body, modelID, toolNameMap + } + return newBody, modelID, toolNameMap +} + +func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { + if parsed == nil || account == nil { + return "" + } + if parsed.MetadataUserID != "" { + return "" + } + + userID := strings.TrimSpace(account.GetClaudeUserID()) + if userID == "" && fp != nil { + userID = fp.ClientID + } + if userID == "" { + // Fall back to a random, well-formed client id so we can still satisfy + // Claude Code OAuth requirements when account metadata is incomplete. + userID = generateClientID() + } + + sessionHash := s.GenerateSessionHash(parsed) + sessionID := uuid.NewString() + if sessionHash != "" { + seed := fmt.Sprintf("%d::%s", account.ID, sessionHash) + sessionID = generateSessionUUID(seed) + } + + // Prefer the newer format that includes account_uuid (if present), + // otherwise fall back to the legacy Claude Code format. + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + if accountUUID != "" { + return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID) + } + return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) +} + +func generateSessionUUID(seed string) string { + if seed == "" { + return uuid.NewString() + } + hash := sha256.Sum256([]byte(seed)) + bytes := hash[:16] + bytes[6] = (bytes[6] & 0x0f) | 0x40 + bytes[8] = (bytes[8] & 0x3f) | 0x80 + return fmt.Sprintf("%x-%x-%x-%x-%x", + bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) +} + // SelectAccount 选择账号(粘性会话+优先级) func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") @@ -2060,6 +2628,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool { return claudeCliUserAgentRe.MatchString(userAgent) } +func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { + if IsClaudeCodeClient(ctx) { + return true + } + if parsed == nil || c == nil { + return false + } + return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) +} + // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 // 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) func systemIncludesClaudeCodePrompt(system any) bool { @@ -2096,6 +2674,10 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { "text": claudeCodeSystemPrompt, "cache_control": map[string]string{"type": "ephemeral"}, } + // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code + // banner, it also prefixes the next system instruction with the same banner plus + // a blank line. This helps when upstream concatenates system instructions. + claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) var newSystem []any @@ -2103,19 +2685,36 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { case nil: newSystem = []any{claudeCodeBlock} case string: - if v == "" || v == claudeCodeSystemPrompt { + // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. + if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { newSystem = []any{claudeCodeBlock} } else { - newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}} + // Mirror opencode behavior: keep the banner as a separate system entry, + // but also prefix the next system text with the banner. + merged := v + if !strings.HasPrefix(v, claudeCodePrefix) { + merged = claudeCodePrefix + "\n\n" + v + } + newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}} } case []any: newSystem = make([]any, 0, len(v)+1) newSystem = append(newSystem, claudeCodeBlock) + prefixedNext := false for _, item := range v { if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { continue } + // Prefix the first subsequent text system block once. + if !prefixedNext { + if blockType, _ := m["type"].(string); blockType == "text" { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + m["text"] = claudeCodePrefix + "\n\n" + text + prefixedNext = true + } + } + } } newSystem = append(newSystem, item) } @@ -2319,21 +2918,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A body := parsed.Body reqModel := parsed.Model reqStream := parsed.Stream + originalModel := reqModel + var toolNameMap map[string]string - // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) - // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 - if account.IsOAuth() && - !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) && - !strings.Contains(strings.ToLower(reqModel), "haiku") && - !systemIncludesClaudeCodePrompt(parsed.System) { - body = injectClaudeCodePrompt(body, parsed.System) + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + if !strings.Contains(strings.ToLower(reqModel), "haiku") && + !systemIncludesClaudeCodePrompt(parsed.System) { + body = injectClaudeCodePrompt(body, parsed.System) + } + + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + if s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if err == nil && fp != nil { + if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { + normalizeOpts.injectMetadata = true + normalizeOpts.metadataUserID = metadataUserID + } + } + } + + body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) // 应用模型映射(APIKey 明确映射优先,其次使用 Anthropic 前缀映射) - originalModel := reqModel mappedModel := reqModel mappingSource := "" if account.Type == AccountTypeAPIKey { @@ -2377,10 +2993,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) // Capture upstream request body for ops retry of this attempt. c.Set(OpsUpstreamRequestBodyKey, string(body)) - + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if err != nil { return nil, err } @@ -2458,7 +3073,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) + retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -2490,7 +3105,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) - retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel) + retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { @@ -2715,7 +3330,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A var firstTokenMs *int var clientDisconnect bool if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ @@ -2728,7 +3343,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A firstTokenMs = streamResult.firstTokenMs clientDisconnect = streamResult.clientDisconnect } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode) if err != nil { return nil, err } @@ -2745,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { @@ -2759,11 +3374,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + // OAuth账号:应用统一指纹 var fingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err != nil { log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers @@ -2794,7 +3414,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // 白名单透传headers - for key, values := range c.Request.Header { + for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { for _, v := range values { @@ -2815,10 +3435,30 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex if req.Header.Get("anthropic-version") == "" { req.Header.Set("anthropic-version", "2023-06-01") } - - // 处理anthropic-beta header(OAuth账号需要特殊处理) if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + applyClaudeOAuthHeaderDefaults(req, reqStream) + } + + // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) + if tokenType == "oauth" { + if mimicClaudeCode { + // 非 Claude Code 客户端:按 opencode 的策略处理: + // - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app) + // - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在 + applyClaudeCodeMimicHeaders(req, reqStream) + + incomingBeta := req.Header.Get("anthropic-beta") + // Match real Claude CLI traffic (per mitmproxy reports): + // messages requests typically use only oauth + interleaved-thinking. + // Also drop claude-code beta if a downstream client added it. + requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} + drop := map[string]struct{}{claude.BetaClaudeCode: {}} + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + } else { + // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta + clientBetaHeader := req.Header.Get("anthropic-beta") + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) + } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) if requestNeedsBetaFeatures(body) { @@ -2828,6 +3468,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // Always capture a compact fingerprint line for later error diagnostics. + // We only print it when needed (or when the explicit debug flag is enabled). + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + return req, nil } @@ -2897,6 +3546,93 @@ func defaultAPIKeyBetaHeader(body []byte) string { return claude.APIKeyBetaHeader } +func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) { + if req == nil { + return + } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "application/json") + } + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + if req.Header.Get(key) == "" { + req.Header.Set(key, value) + } + } + if isStream && req.Header.Get("x-stainless-helper-method") == "" { + req.Header.Set("x-stainless-helper-method", "stream") + } +} + +func mergeAnthropicBeta(required []string, incoming string) string { + seen := make(map[string]struct{}, len(required)+8) + out := make([]string, 0, len(required)+8) + + add := func(v string) { + v = strings.TrimSpace(v) + if v == "" { + return + } + if _, ok := seen[v]; ok { + return + } + seen[v] = struct{}{} + out = append(out, v) + } + + for _, r := range required { + add(r) + } + for _, p := range strings.Split(incoming, ",") { + add(p) + } + return strings.Join(out, ",") +} + +func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string { + merged := mergeAnthropicBeta(required, incoming) + if merged == "" || len(drop) == 0 { + return merged + } + out := make([]string, 0, 8) + for _, p := range strings.Split(merged, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if _, ok := drop[p]; ok { + continue + } + out = append(out, p) + } + return strings.Join(out, ",") +} + +// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. +// This mirrors opencode-anthropic-auth behavior: do not trust downstream +// headers when using Claude Code-scoped OAuth credentials. +func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { + if req == nil { + return + } + // Start with the standard defaults (fill missing). + applyClaudeOAuthHeaderDefaults(req, isStream) + // Then force key headers to match Claude Code fingerprint regardless of what the client sent. + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + req.Header.Set(key, value) + } + // Real Claude CLI uses Accept: application/json (even for streaming). + req.Header.Set("accept", "application/json") + if isStream { + req.Header.Set("x-stainless-helper-method", "stream") + } +} + func truncateForLog(b []byte, maxBytes int) string { if maxBytes <= 0 { maxBytes = 2048 @@ -3000,6 +3736,20 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + // Print a compact upstream request fingerprint when we hit the Claude Code OAuth + // credential scope error. This avoids requiring env-var tweaks in a fixed deploy. + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { @@ -3129,6 +3879,19 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes @@ -3181,7 +3944,7 @@ type streamingResult struct { clientDisconnect bool // 客户端是否在流式传输过程中断开 } -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -3276,6 +4039,171 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage + pendingEventLines := make([]string, 0, 4) + var toolInputBuffers map[int]string + if mimicClaudeCode { + toolInputBuffers = make(map[int]string) + } + + transformToolInputJSON := func(raw string) string { + if !mimicClaudeCode { + return raw + } + raw = strings.TrimSpace(raw) + if raw == "" { + return raw + } + + var parsed any + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return replaceToolNamesInText(raw, toolNameMap) + } + + rewritten, changed := rewriteParamKeysInValue(parsed, toolNameMap) + if changed { + if bytes, err := json.Marshal(rewritten); err == nil { + return string(bytes) + } + } + return raw + } + + processSSEEvent := func(lines []string) ([]string, string, error) { + if len(lines) == 0 { + return nil, "", nil + } + + eventName := "" + dataLine := "" + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") { + eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) + continue + } + if dataLine == "" && sseDataRe.MatchString(trimmed) { + dataLine = sseDataRe.ReplaceAllString(trimmed, "") + } + } + + if eventName == "error" { + return nil, dataLine, errors.New("have error in stream") + } + + if dataLine == "" { + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + } + + if dataLine == "[DONE]" { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, nil + } + + var event map[string]any + if err := json.Unmarshal([]byte(dataLine), &event); err != nil { + replaced := dataLine + if mimicClaudeCode { + replaced = replaceToolNamesInText(dataLine, toolNameMap) + } + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + replaced + "\n\n" + return []string{block}, replaced, nil + } + + eventType, _ := event["type"].(string) + if eventName == "" { + eventName = eventType + } + + if needModelReplace { + if msg, ok := event["message"].(map[string]any); ok { + if model, ok := msg["model"].(string); ok && model == mappedModel { + msg["model"] = originalModel + } + } + } + + if mimicClaudeCode && eventType == "content_block_delta" { + if delta, ok := event["delta"].(map[string]any); ok { + if deltaType, _ := delta["type"].(string); deltaType == "input_json_delta" { + if indexVal, ok := event["index"].(float64); ok { + index := int(indexVal) + if partial, ok := delta["partial_json"].(string); ok { + toolInputBuffers[index] += partial + } + } + return nil, dataLine, nil + } + } + } + + if mimicClaudeCode && eventType == "content_block_stop" { + if indexVal, ok := event["index"].(float64); ok { + index := int(indexVal) + if buffered := toolInputBuffers[index]; buffered != "" { + delete(toolInputBuffers, index) + + transformed := transformToolInputJSON(buffered) + synthetic := map[string]any{ + "type": "content_block_delta", + "index": index, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": transformed, + }, + } + + synthBytes, synthErr := json.Marshal(synthetic) + if synthErr == nil { + synthBlock := "event: content_block_delta\n" + "data: " + string(synthBytes) + "\n\n" + + rewriteToolNamesInValue(event, toolNameMap) + stopBytes, stopErr := json.Marshal(event) + if stopErr == nil { + stopBlock := "" + if eventName != "" { + stopBlock = "event: " + eventName + "\n" + } + stopBlock += "data: " + string(stopBytes) + "\n\n" + return []string{synthBlock, stopBlock}, string(stopBytes), nil + } + } + } + } + } + + if mimicClaudeCode { + rewriteToolNamesInValue(event, toolNameMap) + } + newData, err := json.Marshal(event) + if err != nil { + replaced := dataLine + if mimicClaudeCode { + replaced = replaceToolNamesInText(dataLine, toolNameMap) + } + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + replaced + "\n\n" + return []string{block}, replaced, nil + } + + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + string(newData) + "\n\n" + return []string{block}, string(newData), nil + } + for { select { case ev, ok := <-events: @@ -3304,42 +4232,43 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line - if line == "event: error" { - // 上游返回错误事件,如果客户端已断开仍返回已收集的 usage - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + trimmed := strings.TrimSpace(line) + + if trimmed == "" { + if len(pendingEventLines) == 0 { + continue } - return nil, errors.New("have error in stream") + + outputBlocks, data, err := processSSEEvent(pendingEventLines) + pendingEventLines = pendingEventLines[:0] + if err != nil { + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + return nil, err + } + + for _, block := range outputBlocks { + if !clientDisconnected { + if _, werr := fmt.Fprint(w, block); werr != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + break + } + flusher.Flush() + } + if data != "" { + if firstTokenMs == nil && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } + } + continue } - // Extract data from SSE line (supports both "data: " and "data:" formats) - var data string - if sseDataRe.MatchString(line) { - data = sseDataRe.ReplaceAllString(line, "") - // 如果有模型映射,替换响应中的model字段 - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - } - - // 写入客户端(统一处理 data 行和非 data 行) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } - - // 无论客户端是否断开,都解析 usage(仅对 data 行) - if data != "" { - if firstTokenMs == nil && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } + pendingEventLines = append(pendingEventLines, line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -3363,43 +4292,124 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } -// replaceModelInSSELine 替换SSE数据行中的model字段 -func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !sseDataRe.MatchString(line) { - return line +func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) { + switch v := value.(type) { + case map[string]any: + changed := false + rewritten := make(map[string]any, len(v)) + for key, item := range v { + newKey := normalizeParamNameForOpenCode(key, cache) + newItem, childChanged := rewriteParamKeysInValue(item, cache) + if childChanged { + changed = true + } + if newKey != key { + changed = true + } + rewritten[newKey] = newItem + } + if !changed { + return value, false + } + return rewritten, true + case []any: + changed := false + rewritten := make([]any, len(v)) + for idx, item := range v { + newItem, childChanged := rewriteParamKeysInValue(item, cache) + if childChanged { + changed = true + } + rewritten[idx] = newItem + } + if !changed { + return value, false + } + return rewritten, true + default: + return value, false } - data := sseDataRe.ReplaceAllString(line, "") - if data == "" || data == "[DONE]" { - return line +} + +func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool { + switch v := value.(type) { + case map[string]any: + changed := false + if blockType, _ := v["type"].(string); blockType == "tool_use" { + if name, ok := v["name"].(string); ok { + mapped := normalizeToolNameForOpenCode(name, toolNameMap) + if mapped != name { + v["name"] = mapped + changed = true + } + } + if input, ok := v["input"].(map[string]any); ok { + rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap) + if inputChanged { + if m, ok := rewrittenInput.(map[string]any); ok { + v["input"] = m + changed = true + } + } + } + } + for _, item := range v { + if rewriteToolNamesInValue(item, toolNameMap) { + changed = true + } + } + return changed + case []any: + changed := false + for _, item := range v { + if rewriteToolNamesInValue(item, toolNameMap) { + changed = true + } + } + return changed + default: + return false + } +} + +func replaceToolNamesInText(text string, toolNameMap map[string]string) string { + if text == "" { + return text + } + output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string { + submatches := toolNameFieldRe.FindStringSubmatch(match) + if len(submatches) < 2 { + return match + } + name := submatches[1] + mapped := normalizeToolNameForOpenCode(name, toolNameMap) + if mapped == name { + return match + } + return strings.Replace(match, name, mapped, 1) + }) + output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string { + submatches := modelFieldRe.FindStringSubmatch(match) + if len(submatches) < 2 { + return match + } + model := submatches[1] + mapped := claude.DenormalizeModelID(model) + if mapped == model { + return match + } + return strings.Replace(match, model, mapped, 1) + }) + + for mapped, original := range toolNameMap { + if mapped == "" || original == "" || mapped == original { + continue + } + output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":") + output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":") } - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // 只替换 message_start 事件中的 message.model - if event["type"] != "message_start" { - return line - } - - msg, ok := event["message"].(map[string]any) - if !ok { - return line - } - - model, ok := msg["model"].(string) - if !ok || model != fromModel { - return line - } - - msg["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - - return "data: " + string(newData) + return output } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { @@ -3445,7 +4455,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { } } -func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) @@ -3466,6 +4476,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } + if mimicClaudeCode { + body = s.replaceToolNamesInResponseBody(body, toolNameMap) + } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -3503,6 +4516,28 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo return newBody } +func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte { + if len(body) == 0 { + return body + } + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + replaced := replaceToolNamesInText(string(body), toolNameMap) + if replaced == string(body) { + return body + } + return []byte(replaced) + } + if !rewriteToolNamesInValue(resp, toolNameMap) { + return body + } + newBody, err := json.Marshal(resp) + if err != nil { + return body + } + return newBody +} + // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult @@ -3657,6 +4692,162 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } +// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) +type RecordUsageLongContextInput struct { + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + LongContextThreshold int // 长上下文阈值(如 200000) + LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) +} + +// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) +func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { + result := input.Result + apiKey := input.APIKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 获取费率倍数 + multiplier := s.cfg.Default.RateMultiplier + if apiKey.GroupID != nil && apiKey.Group != nil { + multiplier = apiKey.Group.RateMultiplier + } + + var cost *CostBreakdown + + // 根据请求类型选择计费方式 + if result.ImageCount > 0 { + // 图片生成计费 + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) + } else { + // Token 计费(使用长上下文计费方法) + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + } + var err error + cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) + if err != nil { + log.Printf("Calculate cost failed: %v", err) + cost = &CostBreakdown{ActualCost: 0} + } + } + + // 判断计费方式:订阅模式 vs 余额模式 + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // 创建使用日志 + durationMs := int(result.Duration.Milliseconds()) + var imageSize *string + if result.ImageSize != "" { + imageSize = &result.ImageSize + } + accountRateMultiplier := account.BillingRateMultiplier() + usageLog := &UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: result.RequestID, + Model: result.Model, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + AccountRateMultiplier: &accountRateMultiplier, + BillingType: billingType, + Stream: result.Stream, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + ImageCount: result.ImageCount, + ImageSize: imageSize, + CreatedAt: time.Now(), + } + + // 添加 UserAgent + if input.UserAgent != "" { + usageLog.UserAgent = &input.UserAgent + } + + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + + // 添加分组和订阅关联 + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + inserted, err := s.usageLogRepo.Create(ctx, usageLog) + if err != nil { + log.Printf("Create usage log failed: %v", err) + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + shouldBill := inserted || err != nil + + // 根据计费类型执行扣费 + if isSubscriptionBilling { + // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) + if shouldBill && cost.TotalCost > 0 { + if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { + log.Printf("Increment subscription usage failed: %v", err) + } + // 异步更新订阅缓存 + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) + } + } else { + // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) + if shouldBill && cost.ActualCost > 0 { + if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { + log.Printf("Deduct balance failed: %v", err) + } + // 异步更新余额缓存 + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + } + } + + // Schedule batch update for account last_used_at + s.deferredService.ScheduleLastUsedUpdate(account.ID) + + return nil +} + // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { @@ -3668,6 +4859,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, body := parsed.Body reqModel := parsed.Model + isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + } + // Antigravity 账户不支持 count_tokens 转发,直接返回空值 if account.Platform == PlatformAntigravity { c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) @@ -3706,7 +4905,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 构建上游请求 - upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode) if err != nil { s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") return err @@ -3739,7 +4938,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) + retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { @@ -3804,7 +5003,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL if account.Type == AccountTypeAPIKey { @@ -3818,10 +5017,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + // OAuth 账号:应用统一指纹和重写 userID // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 if account.IsOAuth() && s.identityService != nil { - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err == nil { accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { @@ -3845,7 +5049,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // 白名单透传 headers - for key, values := range c.Request.Header { + for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { for _, v := range values { @@ -3856,7 +5060,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用指纹到请求头 if account.IsOAuth() && s.identityService != nil { - fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if fp != nil { s.identityService.ApplyFingerprint(req, fp) } @@ -3869,10 +5073,30 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if req.Header.Get("anthropic-version") == "" { req.Header.Set("anthropic-version", "2023-06-01") } + if tokenType == "oauth" { + applyClaudeOAuthHeaderDefaults(req, false) + } // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + if mimicClaudeCode { + applyClaudeCodeMimicHeaders(req, false) + + incomingBeta := req.Header.Get("anthropic-beta") + requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} + req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) + } else { + clientBetaHeader := req.Header.Get("anthropic-beta") + if clientBetaHeader == "" { + req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader) + } else { + beta := s.getBetaHeader(modelID, clientBetaHeader) + if !strings.Contains(beta, claude.BetaTokenCounting) { + beta = beta + "," + claude.BetaTokenCounting + } + req.Header.Set("anthropic-beta", beta) + } + } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:与 messages 同步的按需 beta 注入(默认关闭) if requestNeedsBetaFeatures(body) { @@ -3882,6 +5106,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + return req, nil } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2e04c73c..bd322991 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -36,6 +36,11 @@ const ( geminiRetryMaxDelay = 16 * time.Second ) +// Gemini tool calling now requires `thoughtSignature` in parts that include `functionCall`. +// Many clients don't send it; we inject a known dummy signature to satisfy the validator. +// Ref: https://ai.google.dev/gemini-api/docs/thought-signatures +const geminiDummyThoughtSignature = "skip_thought_signature_validator" + type GeminiMessagesCompatService struct { accountRepo AccountRepository groupRepo GroupRepository @@ -528,6 +533,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if err != nil { return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) } + geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq) originalClaudeBody := body proxyURL := "" @@ -983,6 +989,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action) } + // Some Gemini upstreams validate tool call parts strictly; ensure any `functionCall` part includes a + // `thoughtSignature` to avoid frequent INVALID_ARGUMENT 400s. + body = ensureGeminiFunctionCallThoughtSignatures(body) + mappedModel := originalModel if account.Type == AccountTypeAPIKey { mappedModel = account.GetMappedModel(originalModel) @@ -2662,6 +2672,58 @@ func nextGeminiDailyResetUnix() *int64 { return &ts } +func ensureGeminiFunctionCallThoughtSignatures(body []byte) []byte { + // Fast path: only run when functionCall is present. + if !bytes.Contains(body, []byte(`"functionCall"`)) { + return body + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + + contentsAny, ok := payload["contents"].([]any) + if !ok || len(contentsAny) == 0 { + return body + } + + modified := false + for _, c := range contentsAny { + cm, ok := c.(map[string]any) + if !ok { + continue + } + partsAny, ok := cm["parts"].([]any) + if !ok || len(partsAny) == 0 { + continue + } + for _, p := range partsAny { + pm, ok := p.(map[string]any) + if !ok || pm == nil { + continue + } + if fc, ok := pm["functionCall"].(map[string]any); !ok || fc == nil { + continue + } + ts, _ := pm["thoughtSignature"].(string) + if strings.TrimSpace(ts) == "" { + pm["thoughtSignature"] = geminiDummyThoughtSignature + modified = true + } + } + } + + if !modified { + return body + } + b, err := json.Marshal(payload) + if err != nil { + return body + } + return b +} + func extractGeminiFinishReason(geminiResp map[string]any) string { if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { @@ -2861,7 +2923,13 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" { toolUseIDToName[id] = name } + signature, _ := bm["signature"].(string) + signature = strings.TrimSpace(signature) + if signature == "" { + signature = geminiDummyThoughtSignature + } parts = append(parts, map[string]any{ + "thoughtSignature": signature, "functionCall": map[string]any{ "name": name, "args": bm["input"], diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index d49f2eb3..f31b40ec 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -1,6 +1,8 @@ package service import ( + "encoding/json" + "strings" "testing" ) @@ -126,3 +128,78 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { }) } } + +func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { + claudeReq := map[string]any{ + "model": "claude-haiku-4-5-20251001", + "max_tokens": 10, + "messages": []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "hi"}, + }, + }, + map[string]any{ + "role": "assistant", + "content": []any{ + map[string]any{"type": "text", "text": "ok"}, + map[string]any{ + "type": "tool_use", + "id": "toolu_123", + "name": "default_api:write_file", + "input": map[string]any{"path": "a.txt", "content": "x"}, + // no signature on purpose + }, + }, + }, + }, + "tools": []any{ + map[string]any{ + "name": "default_api:write_file", + "description": "write file", + "input_schema": map[string]any{ + "type": "object", + "properties": map[string]any{"path": map[string]any{"type": "string"}}, + }, + }, + }, + } + b, _ := json.Marshal(claudeReq) + + out, err := convertClaudeMessagesToGeminiGenerateContent(b) + if err != nil { + t.Fatalf("convert failed: %v", err) + } + s := string(out) + if !strings.Contains(s, "\"functionCall\"") { + t.Fatalf("expected functionCall in output, got: %s", s) + } + if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") { + t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s) + } +} + +func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing.T) { + geminiReq := map[string]any{ + "contents": []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{ + "functionCall": map[string]any{ + "name": "default_api:write_file", + "args": map[string]any{"path": "a.txt"}, + }, + }, + }, + }, + }, + } + b, _ := json.Marshal(geminiReq) + out := ensureGeminiFunctionCallThoughtSignatures(b) + s := string(out) + if !strings.Contains(s, "\"thoughtSignature\":\""+geminiDummyThoughtSignature+"\"") { + t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s) + } +} diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index c63a020c..e7ed80fd 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -218,6 +218,14 @@ func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Contex return 0, nil } +func (m *mockGroupRepoForGemini) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return nil +} + +func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, nil +} + var _ GroupRepository = (*mockGroupRepoForGemini)(nil) // mockGatewayCacheForGemini Gemini 测试用的 cache mock diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 324f347b..a2bf2073 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -29,6 +29,10 @@ type GroupRepository interface { ExistsByName(ctx context.Context, name string) (bool, error) GetAccountCount(ctx context.Context, groupID int64) (int64, error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) + // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) + GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) + // BindAccountsToGroup 将多个账号绑定到指定分组 + BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error } // CreateGroupRequest 创建分组请求 diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index e2e723b0..a620ac4d 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -26,13 +26,13 @@ var ( // 默认指纹值(当客户端未提供时使用) var defaultFingerprint = Fingerprint{ - UserAgent: "claude-cli/2.0.62 (external, cli)", + UserAgent: "claude-cli/2.1.22 (external, cli)", StainlessLang: "js", - StainlessPackageVersion: "0.52.0", + StainlessPackageVersion: "0.70.0", StainlessOS: "Linux", - StainlessArch: "x64", + StainlessArch: "arm64", StainlessRuntime: "node", - StainlessRuntimeVersion: "v22.14.0", + StainlessRuntimeVersion: "v24.13.0", } // Fingerprint represents account fingerprint data @@ -327,7 +327,7 @@ func generateUUIDFromSeed(seed string) string { } // parseUserAgentVersion 解析user-agent版本号 -// 例如:claude-cli/2.0.62 -> (2, 0, 62) +// 例如:claude-cli/2.1.2 -> (2, 1, 2) func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { // 匹配 xxx/x.y.z 格式 matches := userAgentVersionRegex.FindStringSubmatch(ua) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 289a13af..04ea2930 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -156,12 +156,15 @@ type OpenAIUsage struct { // OpenAIForwardResult represents the result of forwarding type OpenAIForwardResult struct { - RequestID string - Usage OpenAIUsage - Model string - Stream bool - Duration time.Duration - FirstTokenMs *int + RequestID string + Usage OpenAIUsage + Model string + // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. + // Stored for usage records display; nil means not provided / not applicable. + ReasoningEffort *string + Stream bool + Duration time.Duration + FirstTokenMs *int } // OpenAIGatewayService handles OpenAI API gateway operations @@ -958,13 +961,16 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) + return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: originalModel, - Stream: reqStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + ReasoningEffort: reasoningEffort, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, }, nil } @@ -1260,15 +1266,29 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // 记录上次收到上游数据的时间,用于控制 keepalive 发送频率 lastDataAt := time.Now() - // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) + // 仅发送一次错误事件,避免多次写入导致协议混乱。 + // 注意:OpenAI `/v1/responses` streaming 事件必须符合 OpenAI Responses schema; + // 否则下游 SDK(例如 OpenCode)会因为类型校验失败而报错。 errorEventSent := false + clientDisconnected := false // 客户端断开后继续 drain 上游以收集 usage sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || clientDisconnected { return } errorEventSent = true - _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) - flusher.Flush() + payload := map[string]any{ + "type": "error", + "sequence_number": 0, + "error": map[string]any{ + "type": "upstream_error", + "message": reason, + "code": reason, + }, + } + if b, err := json.Marshal(payload); err == nil { + _, _ = fmt.Fprintf(w, "data: %s\n\n", b) + flusher.Flush() + } } needModelReplace := originalModel != mappedModel @@ -1280,6 +1300,17 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } if ev.err != nil { + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming, returning collected usage") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -1303,15 +1334,19 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // Correct Codex tool calls if needed (apply_patch -> edit, etc.) if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { + data = correctedData line = "data: " + correctedData } - // Forward line - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + } else { + flusher.Flush() + } } - flusher.Flush() // Record first token time if firstTokenMs == nil && data != "" && data != "[DONE]" { @@ -1321,11 +1356,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp s.parseSSEUsage(data, usage) } else { // Forward non-data lines as-is - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + if !clientDisconnected { + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + } else { + flusher.Flush() + } } - flusher.Flush() } case <-intervalCh: @@ -1333,6 +1371,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if time.Since(lastRead) < streamInterval { continue } + if clientDisconnected { + log.Printf("Upstream timeout after client disconnect, returning collected usage") + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { @@ -1342,11 +1384,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") case <-keepaliveCh: + if clientDisconnected { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } if _, err := fmt.Fprint(w, ":\n\n"); err != nil { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + clientDisconnected = true + log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + continue } flusher.Flush() } @@ -1687,6 +1734,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, + ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, @@ -1881,3 +1929,86 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) }() } + +func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { + if reqBody == nil { + return "", false + } + + // Primary: reasoning.effort + if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { + if effort, ok := reasoning["effort"].(string); ok { + return normalizeOpenAIReasoningEffort(effort), true + } + } + + // Fallback: some clients may use a flat field. + if effort, ok := reqBody["reasoning_effort"].(string); ok { + return normalizeOpenAIReasoningEffort(effort), true + } + + return "", false +} + +func deriveOpenAIReasoningEffortFromModel(model string) string { + if strings.TrimSpace(model) == "" { + return "" + } + + modelID := strings.TrimSpace(model) + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + + parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { + switch r { + case '-', '_', ' ': + return true + default: + return false + } + }) + if len(parts) == 0 { + return "" + } + + return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) +} + +func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { + if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { + if value == "" { + return nil + } + return &value + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func normalizeOpenAIReasoningEffort(raw string) string { + value := strings.ToLower(strings.TrimSpace(raw)) + if value == "" { + return "" + } + + // Normalize separators for "x-high"/"x_high" variants. + value = strings.NewReplacer("-", "", "_", "", " ", "").Replace(value) + + switch value { + case "none", "minimal": + return "" + case "low", "medium", "high": + return value + case "xhigh", "extrahigh": + return "xhigh" + default: + // Only store known effort levels for now to keep UI consistent. + return "" + } +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 1912e244..ae69a986 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -59,6 +59,25 @@ type stubConcurrencyCache struct { skipDefaultLoad bool } +type cancelReadCloser struct{} + +func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled } +func (c cancelReadCloser) Close() error { return nil } + +type failingGinWriter struct { + gin.ResponseWriter + failAfter int + writes int +} + +func (w *failingGinWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { if c.acquireResults != nil { if result, ok := c.acquireResults[accountID]; ok { @@ -814,8 +833,85 @@ func TestOpenAIStreamingTimeout(t *testing.T) { if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { t.Fatalf("expected stream timeout error, got %v", err) } - if !strings.Contains(rec.Body.String(), "stream_timeout") { - t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String()) + if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") { + t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String()) + } +} + +func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: cancelReadCloser{}, + Header: http.Header{}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { + t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) + } +} + +func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if result == nil || result.usage == nil { + t.Fatalf("expected usage result") + } + if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 { + t.Fatalf("unexpected usage: %+v", *result.usage) + } + if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") { + t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) } } @@ -854,8 +950,8 @@ func TestOpenAIStreamingTooLong(t *testing.T) { if !errors.Is(err, bufio.ErrTooLong) { t.Fatalf("expected ErrTooLong, got %v", err) } - if !strings.Contains(rec.Body.String(), "response_too_large") { - t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String()) + if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") { + t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String()) } } diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index ff52dc47..adcafb3f 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -126,7 +126,8 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ return nil, errors.New("count must be greater than 0") } - if req.Value <= 0 { + // 邀请码类型不需要数值,其他类型需要 + if req.Type != RedeemTypeInvitation && req.Value <= 0 { return nil, errors.New("value must be greater than 0") } @@ -139,6 +140,12 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ codeType = RedeemTypeBalance } + // 邀请码类型的 value 设为 0 + value := req.Value + if codeType == RedeemTypeInvitation { + value = 0 + } + codes := make([]RedeemCode, 0, req.Count) for i := 0; i < req.Count; i++ { code, err := s.GenerateRandomCode() @@ -149,7 +156,7 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ codes = append(codes, RedeemCode{ Code: code, Type: codeType, - Value: req.Value, + Value: value, Status: StatusUnused, }) } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 60ae9543..e40d0d0d 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -62,6 +62,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyEmailVerifyEnabled, SettingKeyPromoCodeEnabled, SettingKeyPasswordResetEnabled, + SettingKeyInvitationCodeEnabled, SettingKeyTotpEnabled, SettingKeyTurnstileEnabled, SettingKeyTurnstileSiteKey, @@ -99,6 +100,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings EmailVerifyEnabled: emailVerifyEnabled, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: passwordResetEnabled, + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", TotpEnabled: settings[SettingKeyTotpEnabled] == "true", TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], @@ -141,6 +143,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any EmailVerifyEnabled bool `json:"email_verify_enabled"` PromoCodeEnabled bool `json:"promo_code_enabled"` PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` TotpEnabled bool `json:"totp_enabled"` TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` @@ -161,6 +164,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any EmailVerifyEnabled: settings.EmailVerifyEnabled, PromoCodeEnabled: settings.PromoCodeEnabled, PasswordResetEnabled: settings.PasswordResetEnabled, + InvitationCodeEnabled: settings.InvitationCodeEnabled, TotpEnabled: settings.TotpEnabled, TurnstileEnabled: settings.TurnstileEnabled, TurnstileSiteKey: settings.TurnstileSiteKey, @@ -188,6 +192,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) + updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) // 邮件服务设置(只有非空才更新密码) @@ -286,6 +291,14 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool { return value != "false" } +// IsInvitationCodeEnabled 检查是否启用邀请码注册功能 +func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool { + value, err := s.settingRepo.GetValue(ctx, SettingKeyInvitationCodeEnabled) + if err != nil { + return false // 默认关闭 + } + return value == "true" +} // IsPasswordResetEnabled 检查是否启用密码重置功能 // 要求:必须同时开启邮件验证 func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool { @@ -401,6 +414,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin EmailVerifyEnabled: emailVerifyEnabled, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", + InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", TotpEnabled: settings[SettingKeyTotpEnabled] == "true", SMTPHost: settings[SettingKeySMTPHost], SMTPUsername: settings[SettingKeySMTPUsername], diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 358911dc..0c7bab67 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,11 +1,12 @@ package service type SystemSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - TotpEnabled bool // TOTP 双因素认证 + RegistrationEnabled bool + EmailVerifyEnabled bool + PromoCodeEnabled bool + PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 SMTPHost string SMTPPort int @@ -61,21 +62,22 @@ type SystemSettings struct { } type PublicSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - PasswordResetEnabled bool - TotpEnabled bool // TOTP 双因素认证 - TurnstileEnabled bool - TurnstileSiteKey string - SiteName string - SiteLogo string - SiteSubtitle string - APIBaseURL string - ContactInfo string - DocURL string - HomeContent string - HideCcsImportButton bool + RegistrationEnabled bool + EmailVerifyEnabled bool + PromoCodeEnabled bool + PasswordResetEnabled bool + InvitationCodeEnabled bool + TotpEnabled bool // TOTP 双因素认证 + TurnstileEnabled bool + TurnstileSiteKey string + SiteName string + SiteLogo string + SiteSubtitle string + APIBaseURL string + ContactInfo string + DocURL string + HomeContent string + HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 3b0e934f..a9721d7f 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -14,6 +14,9 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // ReasoningEffort is the request's reasoning effort level (OpenAI Responses API), + // e.g. "low" / "medium" / "high" / "xhigh". Nil means not provided / not applicable. + ReasoningEffort *string GroupID *int64 SubscriptionID *int64 diff --git a/backend/migrations/046_add_usage_log_reasoning_effort.sql b/backend/migrations/046_add_usage_log_reasoning_effort.sql new file mode 100644 index 00000000..f6572d1d --- /dev/null +++ b/backend/migrations/046_add_usage_log_reasoning_effort.sql @@ -0,0 +1,4 @@ +-- Add reasoning_effort field to usage_logs for OpenAI/Codex requests. +-- This stores the request's reasoning effort level (e.g. low/medium/high/xhigh). +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS reasoning_effort VARCHAR(20); + diff --git a/backend/migrations/046_add_group_supported_model_scopes.sql b/backend/migrations/046b_add_group_supported_model_scopes.sql similarity index 100% rename from backend/migrations/046_add_group_supported_model_scopes.sql rename to backend/migrations/046b_add_group_supported_model_scopes.sql diff --git a/docs/rename_local_migrations_20260202.sql b/docs/rename_local_migrations_20260202.sql index a5ba2ef1..911ed17d 100644 --- a/docs/rename_local_migrations_20260202.sql +++ b/docs/rename_local_migrations_20260202.sql @@ -24,4 +24,11 @@ WHERE filename = '044_add_group_mcp_xml_inject.sql' SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql' ); +UPDATE schema_migrations +SET filename = '046b_add_group_supported_model_scopes.sql' +WHERE filename = '046_add_group_supported_model_scopes.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '046b_add_group_supported_model_scopes.sql' + ); + COMMIT; diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index a0595e4f..3dc76fe7 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -14,6 +14,7 @@ export interface SystemSettings { email_verify_enabled: boolean promo_code_enabled: boolean password_reset_enabled: boolean + invitation_code_enabled: boolean totp_enabled: boolean // TOTP 双因素认证 totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 // Default settings @@ -72,6 +73,7 @@ export interface UpdateSettingsRequest { email_verify_enabled?: boolean promo_code_enabled?: boolean password_reset_enabled?: boolean + invitation_code_enabled?: boolean totp_enabled?: boolean // TOTP 双因素认证 default_balance?: number default_concurrency?: number diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index bbd5ed74..40c9c5a4 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -164,6 +164,24 @@ export async function validatePromoCode(code: string): Promise { + const { data } = await apiClient.post('/auth/validate-invitation-code', { code }) + return data +} + /** * Forgot password request */ @@ -229,6 +247,7 @@ export const authAPI = { getPublicSettings, sendVerifyCode, validatePromoCode, + validateInvitationCode, forgotPassword, resetPassword } diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index f6d1b1be..fbb1942a 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -21,6 +21,12 @@ {{ value }} + +