diff --git a/README.md b/README.md index 41a5aca1..99753e45 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,13 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot - + + + + +
pinccpincc PinCC is the official relay service built on Sub2API, offering stable access to Claude Code, Codex, Gemini and other popular models — ready to use, no deployment or maintenance required.
PackyCodeThanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "sub2api" promo code during first recharge to get 10% off.
## Ecosystem diff --git a/README_CN.md b/README_CN.md index 3380cce7..8b6feaba 100644 --- a/README_CN.md +++ b/README_CN.md @@ -48,9 +48,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 - + + + + +
pinccpincc PinCC 是基于 Sub2API 搭建的官方中转服务,提供 Claude Code、Codex、Gemini 等主流模型的稳定中转,开箱即用,免去自建部署与运维烦恼。
PackyCode感谢 PackyCode 赞助了本项目!PackyCode 是一家稳定、高效的API中转服务商,提供 Claude Code、Codex、Gemini 等多种中转服务。PackyCode 为本软件的用户提供了特别优惠,使用此链接注册并在充值时填写"sub2api"优惠码,首次充值可以享受9折优惠!
## 生态项目 diff --git a/README_JA.md b/README_JA.md index c60b1a8e..1266bd84 100644 --- a/README_JA.md +++ b/README_JA.md @@ -49,9 +49,13 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを - + + + + +
pinccpincc PinCC は Sub2API 上に構築された公式リレーサービスで、Claude Code、Codex、Gemini などの人気モデルへの安定したアクセスを提供します。デプロイやメンテナンスは不要で、すぐにご利用いただけます。
PackyCodePackyCode のご支援に感謝します!PackyCode は Claude Code、Codex、Gemini などのリレーサービスを提供する信頼性の高い API 中継プラットフォームです。本ソフト利用者向けに特別割引があります:このリンクで登録し、チャージ時に「sub2api」クーポンを入力すると 10% オフになります。
## エコシステム diff --git a/assets/partners/logos/packycode.png b/assets/partners/logos/packycode.png new file mode 100644 index 00000000..4fc7eecc Binary files /dev/null and b/assets/partners/logos/packycode.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 23175873..9e3db2aa 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.105 +0.1.106 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 300cda00..ce898a4a 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -137,7 +137,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) + internal500CounterCache := repository.NewInternal500CounterCache(redisClient) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d1cb76db..3ee5d6cd 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1281,8 +1281,8 @@ func setDefaults() { viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) - viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index e39b36d3..0b5448af 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -268,6 +268,14 @@ func AccountFromServiceShallow(a *service.Account) *Account { target := a.GetCacheTTLOverrideTarget() out.CacheTTLOverrideTarget = &target } + // 自定义 Base URL 中继转发 + if a.IsCustomBaseURLEnabled() { + enabled := true + out.CustomBaseURLEnabled = &enabled + if customURL := a.GetCustomBaseURL(); customURL != "" { + out.CustomBaseURL = &customURL + } + } } // 提取账号配额限制(apikey / bedrock 类型有效) diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index aa419d6b..8af6990e 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -198,6 +198,10 @@ type Account struct { CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"` CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"` + // 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效) + CustomBaseURLEnabled *bool `json:"custom_base_url_enabled,omitempty"` + CustomBaseURL *string `json:"custom_base_url,omitempty"` + // API Key 账号配额限制 QuotaLimit *float64 `json:"quota_limit,omitempty"` QuotaUsed *float64 `json:"quota_used,omitempty"` diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 3ce6e5d6..ae70cee4 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -541,6 +541,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { return } reqModel := modelResult.String() + routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) reqStream := gjson.GetBytes(body, "stream").Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) @@ -606,7 +607,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { apiKey.GroupID, "", // no previous_response_id sessionHash, - reqModel, + routingModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, ) @@ -621,7 +622,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if apiKey.Group != nil { defaultModel = apiKey.Group.DefaultMappedModel } - if defaultModel != "" && defaultModel != reqModel { + if defaultModel != "" && defaultModel != routingModel { reqLog.Info("openai_messages.fallback_to_default_model", zap.String("default_mapped_model", defaultModel), ) diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index cfc91bee..c5ef3c6e 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -24,20 +24,18 @@ const ( RedirectURI = "https://platform.claude.com/oauth/code/callback" // Scopes - Browser URL (includes org:create_api_key for user authorization) - ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers" + ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload" // Scopes - Internal API call (org:create_api_key not supported in API) - ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers" + ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload" // Scopes - Setup token (inference only) ScopeInference = "user:inference" - // Code Verifier character set (RFC 7636 compliant) - codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" - // Session TTL SessionTTL = 30 * time.Minute ) // OAuthSession stores OAuth flow state + type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` @@ -147,30 +145,14 @@ func GenerateSessionID() (string, error) { return hex.EncodeToString(bytes), nil } -// GenerateCodeVerifier generates a PKCE code verifier using character set method +// GenerateCodeVerifier generates a PKCE code verifier (RFC 7636). +// Uses 32 random bytes → base64url-no-pad, producing a 43-char verifier. func GenerateCodeVerifier() (string, error) { - const targetLen = 32 - charsetLen := len(codeVerifierCharset) - limit := 256 - (256 % charsetLen) - - result := make([]byte, 0, targetLen) - randBuf := make([]byte, targetLen*2) - - for len(result) < targetLen { - if _, err := rand.Read(randBuf); err != nil { - return "", err - } - for _, b := range randBuf { - if int(b) < limit { - result = append(result, codeVerifierCharset[int(b)%charsetLen]) - if len(result) >= targetLen { - break - } - } - } + bytes, err := GenerateRandomBytes(32) + if err != nil { + return "", err } - - return base64URLEncode(result), nil + return base64URLEncode(bytes), nil } // GenerateCodeChallenge generates a PKCE code challenge using S256 method diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 859eefd5..667193a6 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "fmt" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -257,9 +258,12 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro } func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error { + // 存在唯一键约束 生成tombstone key 用来释放原key,长度远小于 128,满足 schema 限制 + tombstoneKey := fmt.Sprintf("__deleted__%d__%d", id, time.Now().UnixNano()) // 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。 affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). + SetKey(tombstoneKey). SetDeletedAt(time.Now()). Save(ctx) if err != nil { diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index a8989ff2..7d5c1826 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -151,6 +151,31 @@ func (s *APIKeyRepoSuite) TestDelete() { s.Require().Error(err, "expected error after delete") } +func (s *APIKeyRepoSuite) TestCreate_AfterSoftDelete_AllowsSameKey() { + user := s.mustCreateUser("recreate-after-soft-delete@test.com") + const reusedKey = "sk-reuse-after-soft-delete" + + first := &service.APIKey{ + UserID: user.ID, + Key: reusedKey, + Name: "First Key", + Status: service.StatusActive, + } + s.Require().NoError(s.repo.Create(s.ctx, first), "create first key") + + s.Require().NoError(s.repo.Delete(s.ctx, first.ID), "soft delete first key") + + second := &service.APIKey{ + UserID: user.ID, + Key: reusedKey, + Name: "Second Key", + Status: service.StatusActive, + } + s.Require().NoError(s.repo.Create(s.ctx, second), "create second key with same key") + s.Require().NotZero(second.ID) + s.Require().NotEqual(first.ID, second.ID, "recreated key should be a new row") +} + // --- ListByUserID / CountByUserID --- func (s *APIKeyRepoSuite) TestListByUserID() { diff --git a/backend/internal/repository/internal500_counter_cache.go b/backend/internal/repository/internal500_counter_cache.go new file mode 100644 index 00000000..13b0faa8 --- /dev/null +++ b/backend/internal/repository/internal500_counter_cache.go @@ -0,0 +1,55 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + internal500CounterPrefix = "internal500_count:account:" + internal500CounterTTLSeconds = 86400 // 24 小时兜底 +) + +// internal500CounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值 +// 如果 key 不存在,则创建并设置过期时间 +var internal500CounterIncrScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + local count = redis.call('INCR', key) + if count == 1 then + redis.call('EXPIRE', key, ttl) + end + + return count +`) + +type internal500CounterCache struct { + rdb *redis.Client +} + +// NewInternal500CounterCache 创建 INTERNAL 500 连续失败计数器缓存实例 +func NewInternal500CounterCache(rdb *redis.Client) service.Internal500CounterCache { + return &internal500CounterCache{rdb: rdb} +} + +// IncrementInternal500Count 原子递增计数并返回当前值 +func (c *internal500CounterCache) IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) { + key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID) + + result, err := internal500CounterIncrScript.Run(ctx, c.rdb, []string{key}, internal500CounterTTLSeconds).Int64() + if err != nil { + return 0, fmt.Errorf("increment internal500 count: %w", err) + } + + return result, nil +} + +// ResetInternal500Count 清零计数器(成功响应时调用) +func (c *internal500CounterCache) ResetInternal500Count(ctx context.Context, accountID int64) error { + key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index f65f9beb..49d47bf6 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -81,6 +81,7 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, NewTimeoutCounterCache, + NewInternal500CounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 741e33e8..a1449ffd 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -1229,6 +1229,28 @@ func (a *Account) IsSessionIDMaskingEnabled() bool { return false } +// IsCustomBaseURLEnabled 检查是否启用自定义 base URL 中继转发 +// 仅适用于 Anthropic OAuth/SetupToken 类型账号 +func (a *Account) IsCustomBaseURLEnabled() bool { + if !a.IsAnthropicOAuthOrSetupToken() { + return false + } + if a.Extra == nil { + return false + } + if v, ok := a.Extra["custom_base_url_enabled"]; ok { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + +// GetCustomBaseURL 返回自定义中继服务的 base URL +func (a *Account) GetCustomBaseURL() string { + return a.GetExtraString("custom_base_url") +} + // IsCacheTTLOverrideEnabled 检查是否启用缓存 TTL 强制替换 // 仅适用于 Anthropic OAuth/SetupToken 类型账号 // 启用后将所有 cache creation tokens 归入指定的 TTL 类型(5m 或 1h) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index fa175d5d..88c064f3 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1866,6 +1866,18 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac if err := s.accountRepo.ClearError(ctx, id); err != nil { return nil, err } + if err := s.accountRepo.ClearRateLimit(ctx, id); err != nil { + return nil, err + } + if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, id); err != nil { + return nil, err + } + if err := s.accountRepo.ClearModelRateLimits(ctx, id); err != nil { + return nil, err + } + if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil { + return nil, err + } return s.accountRepo.GetByID(ctx, id) } diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go new file mode 100644 index 00000000..f039612c --- /dev/null +++ b/backend/internal/service/admin_service_clear_error_test.go @@ -0,0 +1,86 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type accountRepoStubForClearAccountError struct { + mockAccountRepoForGemini + account *Account + clearErrorCalls int + clearRateLimitCalls int + clearAntigravityCalls int + clearModelRateLimitCalls int + clearTempUnschedCalls int +} + +func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) { + return r.account, nil +} + +func (r *accountRepoStubForClearAccountError) ClearError(ctx context.Context, id int64) error { + r.clearErrorCalls++ + r.account.Status = StatusActive + r.account.ErrorMessage = "" + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearRateLimit(ctx context.Context, id int64) error { + r.clearRateLimitCalls++ + r.account.RateLimitedAt = nil + r.account.RateLimitResetAt = nil + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + r.clearAntigravityCalls++ + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearModelRateLimits(ctx context.Context, id int64) error { + r.clearModelRateLimitCalls++ + return nil +} + +func (r *accountRepoStubForClearAccountError) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempUnschedCalls++ + r.account.TempUnschedulableUntil = nil + r.account.TempUnschedulableReason = "" + return nil +} + +func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *testing.T) { + until := time.Now().Add(10 * time.Minute) + resetAt := time.Now().Add(5 * time.Minute) + repo := &accountRepoStubForClearAccountError{ + account: &Account{ + ID: 31, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "refresh failed", + RateLimitResetAt: &resetAt, + TempUnschedulableUntil: &until, + TempUnschedulableReason: "missing refresh token", + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + updated, err := svc.ClearAccountError(context.Background(), 31) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 1, repo.clearErrorCalls) + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Nil(t, updated.RateLimitResetAt) + require.Nil(t, updated.TempUnschedulableUntil) + require.Empty(t, updated.TempUnschedulableReason) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index aa5d948c..a76e59fb 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -614,6 +614,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP urlFallbackLoop: for urlIdx, baseURL := range availableURLs { usedBaseURL = baseURL + allAttemptsInternal500 := true // 追踪本轮所有 attempt 是否全部命中 INTERNAL 500 for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { select { case <-p.ctx.Done(): @@ -766,10 +767,19 @@ urlFallbackLoop: logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } + // 追踪 INTERNAL 500:非匹配的 attempt 清除标记 + if !isAntigravityInternalServerError(resp.StatusCode, respBody) { + allAttemptsInternal500 = false + } continue } } + // INTERNAL 500 渐进惩罚:3 次重试全部命中特定 500 时递增计数器并惩罚 + if allAttemptsInternal500 && isAntigravityInternalServerError(resp.StatusCode, respBody) { + s.handleInternal500RetryExhausted(p.ctx, p.prefix, p.account) + } + // 其他 4xx 错误或重试用尽,直接返回 resp = &http.Response{ StatusCode: resp.StatusCode, @@ -788,6 +798,11 @@ urlFallbackLoop: antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL) } + // 成功响应时清零 INTERNAL 500 连续失败计数器(覆盖所有成功路径,含 smart retry) + if resp != nil && resp.StatusCode < 400 { + s.resetInternal500Counter(p.ctx, p.prefix, p.account.ID) + } + return &antigravityRetryLoopResult{resp: resp}, nil } @@ -862,6 +877,7 @@ type AntigravityGatewayService struct { settingService *SettingService cache GatewayCache // 用于模型级限流时清除粘性会话绑定 schedulerSnapshot *SchedulerSnapshotService + internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器 } func NewAntigravityGatewayService( @@ -872,6 +888,7 @@ func NewAntigravityGatewayService( rateLimitService *RateLimitService, httpUpstream HTTPUpstream, settingService *SettingService, + internal500Cache Internal500CounterCache, ) *AntigravityGatewayService { return &AntigravityGatewayService{ accountRepo: accountRepo, @@ -881,6 +898,7 @@ func NewAntigravityGatewayService( settingService: settingService, cache: cache, schedulerSnapshot: schedulerSnapshot, + internal500Cache: internal500Cache, } } diff --git a/backend/internal/service/antigravity_internal500_penalty.go b/backend/internal/service/antigravity_internal500_penalty.go new file mode 100644 index 00000000..747a4d4e --- /dev/null +++ b/backend/internal/service/antigravity_internal500_penalty.go @@ -0,0 +1,97 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/tidwall/gjson" +) + +// INTERNAL 500 渐进惩罚:连续多轮全部返回特定 500 错误时的惩罚时长 +const ( + internal500PenaltyTier1Duration = 30 * time.Minute // 第 1 轮:临时不可调度 30 分钟 + internal500PenaltyTier2Duration = 2 * time.Hour // 第 2 轮:临时不可调度 2 小时 + internal500PenaltyTier3Threshold = 3 // 第 3+ 轮:永久禁用 +) + +// isAntigravityInternalServerError 检测特定的 INTERNAL 500 错误 +// 必须同时匹配 error.code==500, error.message=="Internal error encountered.", error.status=="INTERNAL" +func isAntigravityInternalServerError(statusCode int, body []byte) bool { + if statusCode != http.StatusInternalServerError { + return false + } + return gjson.GetBytes(body, "error.code").Int() == 500 && + gjson.GetBytes(body, "error.message").String() == "Internal error encountered." && + gjson.GetBytes(body, "error.status").String() == "INTERNAL" +} + +// applyInternal500Penalty 根据连续 INTERNAL 500 轮次数应用渐进惩罚 +// count=1: temp_unschedulable 10 分钟 +// count=2: temp_unschedulable 10 小时 +// count>=3: SetError 永久禁用 +func (s *AntigravityGatewayService) applyInternal500Penalty( + ctx context.Context, prefix string, account *Account, count int64, +) { + switch { + case count >= int64(internal500PenaltyTier3Threshold): + reason := fmt.Sprintf("INTERNAL 500 consecutive failures: %d rounds", count) + if err := s.accountRepo.SetError(ctx, account.ID, reason); err != nil { + slog.Error("internal500_set_error_failed", "account_id", account.ID, "error", err) + return + } + slog.Warn("internal500_account_disabled", + "account_id", account.ID, "account_name", account.Name, "consecutive_count", count) + case count == 2: + until := time.Now().Add(internal500PenaltyTier2Duration) + reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier2Duration) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err) + return + } + slog.Warn("internal500_temp_unschedulable", + "account_id", account.ID, "account_name", account.Name, + "duration", internal500PenaltyTier2Duration, "consecutive_count", count) + case count == 1: + until := time.Now().Add(internal500PenaltyTier1Duration) + reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier1Duration) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err) + return + } + slog.Info("internal500_temp_unschedulable", + "account_id", account.ID, "account_name", account.Name, + "duration", internal500PenaltyTier1Duration, "consecutive_count", count) + } +} + +// handleInternal500RetryExhausted 处理 INTERNAL 500 重试耗尽:递增计数器并应用惩罚 +func (s *AntigravityGatewayService) handleInternal500RetryExhausted( + ctx context.Context, prefix string, account *Account, +) { + if s.internal500Cache == nil { + return + } + count, err := s.internal500Cache.IncrementInternal500Count(ctx, account.ID) + if err != nil { + slog.Error("internal500_counter_increment_failed", + "prefix", prefix, "account_id", account.ID, "error", err) + return + } + s.applyInternal500Penalty(ctx, prefix, account, count) +} + +// resetInternal500Counter 成功响应时清零 INTERNAL 500 计数器 +func (s *AntigravityGatewayService) resetInternal500Counter( + ctx context.Context, prefix string, accountID int64, +) { + if s.internal500Cache == nil { + return + } + if err := s.internal500Cache.ResetInternal500Count(ctx, accountID); err != nil { + slog.Error("internal500_counter_reset_failed", + "prefix", prefix, "account_id", accountID, "error", err) + } +} diff --git a/backend/internal/service/antigravity_internal500_penalty_test.go b/backend/internal/service/antigravity_internal500_penalty_test.go new file mode 100644 index 00000000..03831839 --- /dev/null +++ b/backend/internal/service/antigravity_internal500_penalty_test.go @@ -0,0 +1,321 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- mock: Internal500CounterCache --- + +type mockInternal500Cache struct { + incrementCount int64 + incrementErr error + resetErr error + + incrementCalls []int64 // 记录 IncrementInternal500Count 被调用时的 accountID + resetCalls []int64 // 记录 ResetInternal500Count 被调用时的 accountID +} + +func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) { + m.incrementCalls = append(m.incrementCalls, accountID) + return m.incrementCount, m.incrementErr +} + +func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error { + m.resetCalls = append(m.resetCalls, accountID) + return m.resetErr +} + +// --- mock: 专用于 internal500 惩罚测试的 AccountRepository --- + +type internal500AccountRepoStub struct { + AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用) + + tempUnschedCalls []tempUnschedCall + setErrorCalls []setErrorCall +} + +type tempUnschedCall struct { + accountID int64 + until time.Time + reason string +} + +type setErrorCall struct { + accountID int64 + reason string +} + +func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error { + r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason}) + return nil +} + +func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error { + r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg}) + return nil +} + +// ============================================================================= +// TestIsAntigravityInternalServerError +// ============================================================================= + +func TestIsAntigravityInternalServerError(t *testing.T) { + t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.True(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("statusCode 不是 500", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(429, body)) + require.False(t, isAntigravityInternalServerError(503, body)) + require.False(t, isAntigravityInternalServerError(200, body)) + }) + + t.Run("body 中 message 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("body 中 status 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("body 中 code 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("空 body", func(t *testing.T) { + require.False(t, isAntigravityInternalServerError(500, []byte{})) + require.False(t, isAntigravityInternalServerError(500, nil)) + }) + + t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) { + body := []byte(`Internal Server Error`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) { + body := []byte(`{"message":"Internal Server Error","statusCode":500}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) +} + +// ============================================================================= +// TestApplyInternal500Penalty +// ============================================================================= + +func TestApplyInternal500Penalty(t *testing.T) { + t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 1, Name: "acc-1"} + + before := time.Now() + svc.applyInternal500Penalty(context.Background(), "[test]", account, 1) + after := time.Now() + + require.Len(t, repo.tempUnschedCalls, 1) + require.Empty(t, repo.setErrorCalls) + + call := repo.tempUnschedCalls[0] + require.Equal(t, int64(1), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500") + // until 应在 [before+10m, after+10m] 范围内 + require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second))) + require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second))) + }) + + t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 2, Name: "acc-2"} + + before := time.Now() + svc.applyInternal500Penalty(context.Background(), "[test]", account, 2) + after := time.Now() + + require.Len(t, repo.tempUnschedCalls, 1) + require.Empty(t, repo.setErrorCalls) + + call := repo.tempUnschedCalls[0] + require.Equal(t, int64(2), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500") + require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second))) + require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second))) + }) + + t.Run("count=3 → SetError 永久禁用", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 3, Name: "acc-3"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 3) + + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + + call := repo.setErrorCalls[0] + require.Equal(t, int64(3), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3") + }) + + t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 5, Name: "acc-5"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 5) + + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + + call := repo.setErrorCalls[0] + require.Equal(t, int64(5), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5") + }) + + t.Run("count=0 → 不调用任何方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 10, Name: "acc-10"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 0) + + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) +} + +// ============================================================================= +// TestHandleInternal500RetryExhausted +// ============================================================================= + +func TestHandleInternal500RetryExhausted(t *testing.T) { + t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: nil, + } + account := &Account{ID: 1, Name: "acc-1"} + + // 不应 panic + require.NotPanics(t, func() { + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + }) + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementErr: errors.New("redis connection error"), + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 2, Name: "acc-2"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Equal(t, int64(2), cache.incrementCalls[0]) + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementCount: 1, + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 3, Name: "acc-3"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Equal(t, int64(3), cache.incrementCalls[0]) + // tier1: SetTempUnschedulable + require.Len(t, repo.tempUnschedCalls, 1) + require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementCount: 3, + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 4, Name: "acc-4"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + require.Equal(t, int64(4), repo.setErrorCalls[0].accountID) + }) +} + +// ============================================================================= +// TestResetInternal500Counter +// ============================================================================= + +func TestResetInternal500Counter(t *testing.T) { + t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) { + svc := &AntigravityGatewayService{ + internal500Cache: nil, + } + + require.NotPanics(t, func() { + svc.resetInternal500Counter(context.Background(), "[test]", 1) + }) + }) + + t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) { + cache := &mockInternal500Cache{ + resetErr: errors.New("redis timeout"), + } + svc := &AntigravityGatewayService{ + internal500Cache: cache, + } + + require.NotPanics(t, func() { + svc.resetInternal500Counter(context.Background(), "[test]", 42) + }) + require.Len(t, cache.resetCalls, 1) + require.Equal(t, int64(42), cache.resetCalls[0]) + }) + + t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) { + cache := &mockInternal500Cache{} + svc := &AntigravityGatewayService{ + internal500Cache: cache, + } + + svc.resetInternal500Counter(context.Background(), "[test]", 99) + + require.Len(t, cache.resetCalls, 1) + require.Equal(t, int64(99), cache.resetCalls[0]) + }) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5b7a97b0..b54f463b 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,6 +12,7 @@ import ( "log/slog" mathrand "math/rand" "net/http" + "net/url" "os" "path/filepath" "regexp" @@ -368,6 +369,8 @@ var allowedHeaders = map[string]bool{ "user-agent": true, "content-type": true, "accept-encoding": true, + "x-claude-code-session-id": true, + "x-client-request-id": true, } // GatewayCache 定义网关服务的缓存操作接口。 @@ -4150,10 +4153,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) @@ -5628,6 +5633,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } targetURL = validatedURL + "/v1/messages?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account) } clientHeaders := http.Header{} @@ -5743,6 +5758,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ "url": req.URL.String(), @@ -8063,10 +8087,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 发送请求 @@ -8345,6 +8371,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) } clientHeaders := http.Header{} @@ -8450,6 +8486,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -8471,6 +8516,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } +// buildCustomRelayURL 构建自定义中继转发 URL +// 在 path 后附加 beta=true 和可选的 proxy 查询参数 +func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { + u := strings.TrimRight(baseURL, "/") + path + "?beta=true" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL := account.Proxy.URL() + if proxyURL != "" { + u += "&proxy=" + url.QueryEscape(proxyURL) + } + } + return u +} + func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) diff --git a/backend/internal/service/header_util.go b/backend/internal/service/header_util.go index 6acfee5a..1091070d 100644 --- a/backend/internal/service/header_util.go +++ b/backend/internal/service/header_util.go @@ -36,6 +36,11 @@ var headerWireCasing = map[string]string{ "sec-fetch-mode": "sec-fetch-mode", "accept-encoding": "accept-encoding", "authorization": "authorization", + + // Claude Code 2.1.87+ 新增 header + "x-claude-code-session-id": "X-Claude-Code-Session-Id", + "x-client-request-id": "x-client-request-id", + "content-length": "content-length", } // headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。 @@ -55,11 +60,14 @@ var headerWireOrder = []string{ "authorization", "x-app", "User-Agent", + "X-Claude-Code-Session-Id", "content-type", "anthropic-beta", + "x-client-request-id", "accept-language", "sec-fetch-mode", "accept-encoding", + "content-length", "x-stainless-helper-method", } diff --git a/backend/internal/service/internal500_counter.go b/backend/internal/service/internal500_counter.go new file mode 100644 index 00000000..0f0bc50c --- /dev/null +++ b/backend/internal/service/internal500_counter.go @@ -0,0 +1,11 @@ +package service + +import "context" + +// Internal500CounterCache 追踪 Antigravity 账号连续 INTERNAL 500 失败轮数 +type Internal500CounterCache interface { + // IncrementInternal500Count 原子递增计数并返回当前值 + IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) + // ResetInternal500Count 清零计数器(成功响应时调用) + ResetInternal500Count(ctx context.Context, accountID int64) error +} diff --git a/backend/internal/service/openai_compat_model.go b/backend/internal/service/openai_compat_model.go new file mode 100644 index 00000000..5f140d01 --- /dev/null +++ b/backend/internal/service/openai_compat_model.go @@ -0,0 +1,103 @@ +package service + +import ( + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" +) + +func NormalizeOpenAICompatRequestedModel(model string) string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "" + } + + normalized, _, ok := splitOpenAICompatReasoningModel(trimmed) + if !ok || normalized == "" { + return trimmed + } + return normalized +} + +func applyOpenAICompatModelNormalization(req *apicompat.AnthropicRequest) { + if req == nil { + return + } + + originalModel := strings.TrimSpace(req.Model) + if originalModel == "" { + return + } + + normalizedModel, derivedEffort, hasReasoningSuffix := splitOpenAICompatReasoningModel(originalModel) + if hasReasoningSuffix && normalizedModel != "" { + req.Model = normalizedModel + } + + if req.OutputConfig != nil && strings.TrimSpace(req.OutputConfig.Effort) != "" { + return + } + + claudeEffort := openAIReasoningEffortToClaudeOutputEffort(derivedEffort) + if claudeEffort == "" { + return + } + + if req.OutputConfig == nil { + req.OutputConfig = &apicompat.AnthropicOutputConfig{} + } + req.OutputConfig.Effort = claudeEffort +} + +func splitOpenAICompatReasoningModel(model string) (normalizedModel string, reasoningEffort string, ok bool) { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "", "", false + } + + modelID := trimmed + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + modelID = strings.TrimSpace(modelID) + if !strings.HasPrefix(strings.ToLower(modelID), "gpt-") { + return trimmed, "", false + } + + parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { + switch r { + case '-', '_', ' ': + return true + default: + return false + } + }) + if len(parts) == 0 { + return trimmed, "", false + } + + last := strings.NewReplacer("-", "", "_", "", " ", "").Replace(parts[len(parts)-1]) + switch last { + case "none", "minimal": + case "low", "medium", "high": + reasoningEffort = last + case "xhigh", "extrahigh": + reasoningEffort = "xhigh" + default: + return trimmed, "", false + } + + return normalizeCodexModel(modelID), reasoningEffort, true +} + +func openAIReasoningEffortToClaudeOutputEffort(effort string) string { + switch strings.TrimSpace(effort) { + case "low", "medium", "high": + return effort + case "xhigh": + return "max" + default: + return "" + } +} diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go new file mode 100644 index 00000000..32c646d4 --- /dev/null +++ b/backend/internal/service/openai_compat_model_test.go @@ -0,0 +1,129 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestNormalizeOpenAICompatRequestedModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + {name: "gpt reasoning alias strips xhigh", input: "gpt-5.4-xhigh", want: "gpt-5.4"}, + {name: "gpt reasoning alias strips none", input: "gpt-5.4-none", want: "gpt-5.4"}, + {name: "codex max model stays intact", input: "gpt-5.1-codex-max", want: "gpt-5.1-codex-max"}, + {name: "non openai model unchanged", input: "claude-opus-4-6", want: "claude-opus-4-6"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, NormalizeOpenAICompatRequestedModel(tt.input)) + }) + } +} + +func TestApplyOpenAICompatModelNormalization(t *testing.T) { + t.Parallel() + + t.Run("derives xhigh from model suffix when output config missing", func(t *testing.T) { + req := &apicompat.AnthropicRequest{Model: "gpt-5.4-xhigh"} + + applyOpenAICompatModelNormalization(req) + + require.Equal(t, "gpt-5.4", req.Model) + require.NotNil(t, req.OutputConfig) + require.Equal(t, "max", req.OutputConfig.Effort) + }) + + t.Run("explicit output config wins over model suffix", func(t *testing.T) { + req := &apicompat.AnthropicRequest{ + Model: "gpt-5.4-xhigh", + OutputConfig: &apicompat.AnthropicOutputConfig{Effort: "low"}, + } + + applyOpenAICompatModelNormalization(req) + + require.Equal(t, "gpt-5.4", req.Model) + require.NotNil(t, req.OutputConfig) + require.Equal(t, "low", req.OutputConfig.Effort) + }) + + t.Run("non openai model is untouched", func(t *testing.T) { + req := &apicompat.AnthropicRequest{Model: "claude-opus-4-6"} + + applyOpenAICompatModelNormalization(req) + + require.Equal(t, "claude-opus-4-6", req.Model) + require.Nil(t, req.OutputConfig) + }) +} + +func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gpt-5.4-xhigh","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.completed","response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":5,"output_tokens":2,"total_tokens":7}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compat"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token", + "chatgpt_account_id": "chatgpt-acc", + "model_mapping": map[string]any{ + "gpt-5.4": "gpt-5.4", + }, + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-5.4-xhigh", result.Model) + require.Equal(t, "gpt-5.4", result.UpstreamModel) + require.Equal(t, "gpt-5.4", result.BillingModel) + require.NotNil(t, result.ReasoningEffort) + require.Equal(t, "xhigh", *result.ReasoningEffort) + + require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "xhigh", gjson.GetBytes(upstream.lastBody, "reasoning.effort").String()) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gpt-5.4-xhigh", gjson.GetBytes(rec.Body.Bytes(), "model").String()) + require.Equal(t, "ok", gjson.GetBytes(rec.Body.Bytes(), "content.0.text").String()) + t.Logf("upstream body: %s", string(upstream.lastBody)) + t.Logf("response body: %s", rec.Body.String()) +} diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index e9548b79..02efc23b 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -40,6 +40,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( return nil, fmt.Errorf("parse anthropic request: %w", err) } originalModel := anthropicReq.Model + applyOpenAICompatModelNormalization(&anthropicReq) clientStream := anthropicReq.Stream // client's original stream preference // 2. Convert Anthropic → Responses diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 5aa4db8a..7a636afa 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -895,14 +895,16 @@ func TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetad require.Equal(t, 1, userRepo.deductCalls) } -func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingUpstreamModelFallback(t *testing.T) { +func TestOpenAIGatewayServiceRecordUsage_BillsMappedRequestsUsingRequestedModel(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) usage := OpenAIUsage{InputTokens: 20, OutputTokens: 10} - expectedCost, err := svc.billingService.CalculateCost("gpt-5.1-codex", UsageTokens{ + // Billing should use the requested model ("gpt-5.1"), not the upstream mapped model ("gpt-5.1-codex"). + // This ensures pricing is always based on the model the user requested. + expectedCost, err := svc.billingService.CalculateCost("gpt-5.1", UsageTokens{ InputTokens: 20, OutputTokens: 10, }, 1.1) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c7a74aed..e85f0705 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4153,9 +4153,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) - if result.BillingModel != "" { - billingModel = strings.TrimSpace(result.BillingModel) - } serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 0a1266d9..0f004b01 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -502,6 +502,25 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { + accessToken := account.GetCredential("access_token") + if accessToken != "" { + tokenInfo := &OpenAITokenInfo{ + AccessToken: accessToken, + RefreshToken: "", + IDToken: account.GetCredential("id_token"), + ClientID: account.GetCredential("client_id"), + Email: account.GetCredential("email"), + ChatGPTAccountID: account.GetCredential("chatgpt_account_id"), + ChatGPTUserID: account.GetCredential("chatgpt_user_id"), + OrganizationID: account.GetCredential("organization_id"), + PlanType: account.GetCredential("plan_type"), + } + if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { + tokenInfo.ExpiresAt = expiresAt.Unix() + tokenInfo.ExpiresIn = int64(time.Until(*expiresAt).Seconds()) + } + return tokenInfo, nil + } return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } diff --git a/backend/internal/service/openai_oauth_service_refresh_test.go b/backend/internal/service/openai_oauth_service_refresh_test.go new file mode 100644 index 00000000..a31eb8cb --- /dev/null +++ b/backend/internal/service/openai_oauth_service_refresh_test.go @@ -0,0 +1,54 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientRefreshStub struct { + refreshCalls int32 +} + +func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalls, 1) + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalls, 1) + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_RefreshAccountToken_NoRefreshTokenUsesExistingAccessToken(t *testing.T) { + client := &openaiOAuthClientRefreshStub{} + svc := NewOpenAIOAuthService(nil, client) + + expiresAt := time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339) + account := &Account{ + ID: 77, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "existing-access-token", + "expires_at": expiresAt, + "client_id": "client-id-1", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "existing-access-token", info.AccessToken) + require.Equal(t, "client-id-1", info.ClientID) + require.Zero(t, atomic.LoadInt32(&client.refreshCalls), "existing access token should be reused without calling refresh") +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 10440c60..5623d4b7 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -189,10 +189,38 @@ func (s *PricingService) checkAndUpdatePricing() error { return s.downloadPricingData() } - // 检查文件是否过期 + // 先加载本地文件(确保服务可用),再检查是否需要更新 + if err := s.loadPricingData(pricingFile); err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to load local file, downloading: %v", err) + return s.downloadPricingData() + } + + // 如果配置了哈希URL,通过远程哈希检查是否有更新 + if s.cfg.Pricing.HashURL != "" { + remoteHash, err := s.fetchRemoteHash() + if err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash on startup: %v", err) + return nil // 已加载本地文件,哈希获取失败不影响启动 + } + + s.mu.RLock() + localHash := s.localHash + s.mu.RUnlock() + + if localHash == "" || remoteHash != localHash { + logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs on startup (local=%s remote=%s), downloading...", + localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))]) + if err := s.downloadPricingData(); err != nil { + logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err) + } + } + return nil + } + + // 没有哈希URL时,基于文件年龄检查 info, err := os.Stat(pricingFile) if err != nil { - return s.downloadPricingData() + return nil // 已加载本地文件 } fileAge := time.Since(info.ModTime()) @@ -205,21 +233,11 @@ func (s *PricingService) checkAndUpdatePricing() error { } } - // 加载本地文件 - return s.loadPricingData(pricingFile) + return nil } // syncWithRemote 与远程同步(基于哈希校验) func (s *PricingService) syncWithRemote() error { - pricingFile := s.getPricingFilePath() - - // 计算本地文件哈希 - localHash, err := s.computeFileHash(pricingFile) - if err != nil { - logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err) - return s.downloadPricingData() - } - // 如果配置了哈希URL,从远程获取哈希进行比对 if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() @@ -228,8 +246,13 @@ func (s *PricingService) syncWithRemote() error { return nil // 哈希获取失败不影响正常使用 } - if remoteHash != localHash { - logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") + s.mu.RLock() + localHash := s.localHash + s.mu.RUnlock() + + if localHash == "" || remoteHash != localHash { + logger.LegacyPrintf("service.pricing", "[Pricing] Remote hash differs (local=%s remote=%s), downloading new version...", + localHash[:min(8, len(localHash))], remoteHash[:min(8, len(remoteHash))]) return s.downloadPricingData() } logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") @@ -237,6 +260,7 @@ func (s *PricingService) syncWithRemote() error { } // 没有哈希URL时,基于时间检查 + pricingFile := s.getPricingFilePath() info, err := os.Stat(pricingFile) if err != nil { return s.downloadPricingData() @@ -264,11 +288,12 @@ func (s *PricingService) downloadPricingData() error { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - var expectedHash string + // 获取远程哈希(用于同步锚点,不作为完整性校验) + var remoteHash string if strings.TrimSpace(s.cfg.Pricing.HashURL) != "" { - expectedHash, err = s.fetchRemoteHash() + remoteHash, err = s.fetchRemoteHash() if err != nil { - return fmt.Errorf("fetch remote hash: %w", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash (continuing): %v", err) } } @@ -277,11 +302,13 @@ func (s *PricingService) downloadPricingData() error { return fmt.Errorf("download failed: %w", err) } - if expectedHash != "" { - actualHash := sha256.Sum256(body) - if !strings.EqualFold(expectedHash, hex.EncodeToString(actualHash[:])) { - return fmt.Errorf("pricing hash mismatch") - } + // 哈希校验:不匹配时仅告警,不阻止更新 + // 远程哈希文件可能与数据文件不同步(如维护者更新了数据但未更新哈希文件) + dataHash := sha256.Sum256(body) + dataHashStr := hex.EncodeToString(dataHash[:]) + if remoteHash != "" && !strings.EqualFold(remoteHash, dataHashStr) { + logger.LegacyPrintf("service.pricing", "[Pricing] Hash mismatch warning: remote=%s data=%s (hash file may be out of sync)", + remoteHash[:min(8, len(remoteHash))], dataHashStr[:8]) } // 解析JSON数据(使用灵活的解析方式) @@ -296,11 +323,14 @@ func (s *PricingService) downloadPricingData() error { logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) } - // 保存哈希 - hash := sha256.Sum256(body) - hashStr := hex.EncodeToString(hash[:]) + // 使用远程哈希作为同步锚点,防止重复下载 + // 当远程哈希不可用时,回退到数据本身的哈希 + syncHash := dataHashStr + if remoteHash != "" { + syncHash = remoteHash + } hashFile := s.getHashFilePath() - if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { + if err := os.WriteFile(hashFile, []byte(syncHash+"\n"), 0644); err != nil { logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err) } @@ -308,7 +338,7 @@ func (s *PricingService) downloadPricingData() error { s.mu.Lock() s.pricingData = data s.lastUpdated = time.Now() - s.localHash = hashStr + s.localHash = syncHash s.mu.Unlock() logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) @@ -486,16 +516,6 @@ func (s *PricingService) validatePricingURL(raw string) (string, error) { return normalized, nil } -// computeFileHash 计算文件哈希 -func (s *PricingService) computeFileHash(filePath string) (string, error) { - data, err := os.ReadFile(filePath) - if err != nil { - return "", err - } - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]), nil -} - // GetModelPricing 获取模型价格(带模糊匹配) func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing { s.mu.RLock() diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index eb3e5592..d39095ea 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -32,8 +32,9 @@ type TokenRefreshService struct { privacyClientFactory PrivacyClientFactory proxyRepo ProxyRepository - stopCh chan struct{} - wg sync.WaitGroup + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup } // NewTokenRefreshService 创建token刷新服务 @@ -130,7 +131,9 @@ func (s *TokenRefreshService) Start() { // Stop 停止刷新服务(可安全多次调用) func (s *TokenRefreshService) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) s.wg.Wait() slog.Info("token_refresh.service_stopped") } @@ -430,6 +433,7 @@ func isNonRetryableRefreshError(err error) bool { "unauthorized_client", // 客户端未授权 "access_denied", // 访问被拒绝 "missing_project_id", // 缺少 project_id + "no refresh token available", } for _, needle := range nonRetryable { if strings.Contains(msg, needle) { diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 60ba4a96..2179a85e 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -19,6 +19,7 @@ type tokenRefreshAccountRepo struct { updateCredentialsCalls int setErrorCalls int clearTempCalls int + setTempUnschedCalls int lastAccount *Account updateErr error } @@ -58,6 +59,11 @@ func (r *tokenRefreshAccountRepo) ClearTempUnschedulable(ctx context.Context, id return nil } +func (r *tokenRefreshAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.setTempUnschedCalls++ + return nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -490,6 +496,31 @@ func TestTokenRefreshService_RefreshWithRetry_NonRetryableErrorAllPlatforms(t *t } } +func TestTokenRefreshService_RefreshWithRetry_NoRefreshTokenDoesNotTempUnschedule(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 2, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg, nil) + account := &Account{ + ID: 18, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("no refresh token available"), + } + + err := service.refreshWithRetry(context.Background(), account, refresher, refresher, time.Hour) + require.Error(t, err) + require.Equal(t, 0, repo.updateCalls) + require.Equal(t, 0, repo.setTempUnschedCalls, "missing refresh token should not mark the account temp unschedulable") + require.Equal(t, 1, repo.setErrorCalls, "missing refresh token should be treated as a non-retryable credential state") +} + // TestIsNonRetryableRefreshError 测试不可重试错误判断 func TestIsNonRetryableRefreshError(t *testing.T) { tests := []struct { @@ -503,6 +534,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) { {name: "invalid_client", err: errors.New("invalid_client"), expected: true}, {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true}, {name: "access_denied", err: errors.New("access_denied"), expected: true}, + {name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true}, {name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true}, {name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true}, } diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go index 57c51540..a7bcae99 100644 --- a/backend/internal/service/usage_log_helpers.go +++ b/backend/internal/service/usage_log_helpers.go @@ -21,8 +21,8 @@ func optionalNonEqualStringPtr(value, compare string) *string { } func forwardResultBillingModel(requestedModel, upstreamModel string) string { - if trimmedUpstream := strings.TrimSpace(upstreamModel); trimmedUpstream != "" { - return trimmedUpstream + if trimmed := strings.TrimSpace(requestedModel); trimmed != "" { + return trimmed } - return strings.TrimSpace(requestedModel) + return strings.TrimSpace(upstreamModel) } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 2058ced1..8f60acd5 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -865,10 +865,10 @@ rate_limit: pricing: # URL to fetch model pricing data (default: pinned model-price-repo commit) # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo) - remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json" + remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.json" # Hash verification URL (optional) # 哈希校验 URL(可选) - hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256" + hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/refs/heads/main//model_prices_and_context_window.sha256" # Local data directory for caching # 本地数据缓存目录 data_dir: "./data" diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 806b57db..7ffa453f 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2245,6 +2245,41 @@

+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.customBaseUrl.hint') }} +

+
+ +
+
+ +
+
@@ -3095,6 +3130,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([]) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) const cacheTTLOverrideTarget = ref('5m') +const customBaseUrlEnabled = ref(false) +const customBaseUrl = ref('') // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') @@ -3765,6 +3802,8 @@ const resetForm = () => { sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false cacheTTLOverrideTarget.value = '5m' + customBaseUrlEnabled.value = false + customBaseUrl.value = '' allowOverages.value = false antigravityAccountType.value = 'oauth' upstreamBaseUrl.value = '' @@ -4856,6 +4895,12 @@ const handleAnthropicExchange = async (authCode: string) => { extra.cache_ttl_override_target = cacheTTLOverrideTarget.value } + // Add custom base URL settings + if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) { + extra.custom_base_url_enabled = true + extra.custom_base_url = customBaseUrl.value.trim() + } + const credentials: Record = { ...tokenInfo } applyInterceptWarmup(credentials, interceptWarmupRequests.value, 'create') await createAccountAndFinish(form.platform, addMethod.value as AccountType, credentials, extra) @@ -4974,6 +5019,12 @@ const handleCookieAuth = async (sessionKey: string) => { extra.cache_ttl_override_target = cacheTTLOverrideTarget.value } + // Add custom base URL settings + if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) { + extra.custom_base_url_enabled = true + extra.custom_base_url = customBaseUrl.value.trim() + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name const credentials: Record = { ...tokenInfo } diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index da6c9715..607e7a69 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1580,6 +1580,41 @@

+ + +
+
+
+ +

+ {{ t('admin.accounts.quotaControl.customBaseUrl.hint') }} +

+
+ +
+
+ +
+
@@ -1854,6 +1889,8 @@ const tlsFingerprintProfiles = ref<{ id: number; name: string }[]>([]) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) const cacheTTLOverrideTarget = ref('5m') +const customBaseUrlEnabled = ref(false) +const customBaseUrl = ref('') // OpenAI 自动透传开关(OAuth/API Key) const openaiPassthroughEnabled = ref(false) @@ -2482,6 +2519,8 @@ function loadQuotaControlSettings(account: Account) { sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false cacheTTLOverrideTarget.value = '5m' + customBaseUrlEnabled.value = false + customBaseUrl.value = '' // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -2528,6 +2567,12 @@ function loadQuotaControlSettings(account: Account) { cacheTTLOverrideEnabled.value = true cacheTTLOverrideTarget.value = account.cache_ttl_override_target || '5m' } + + // Load custom base URL setting + if (account.custom_base_url_enabled === true) { + customBaseUrlEnabled.value = true + customBaseUrl.value = account.custom_base_url || '' + } } function formatTempUnschedKeywords(value: unknown) { @@ -2980,6 +3025,15 @@ const handleSubmit = async () => { delete newExtra.cache_ttl_override_target } + // Custom base URL relay setting + if (customBaseUrlEnabled.value && customBaseUrl.value.trim()) { + newExtra.custom_base_url_enabled = true + newExtra.custom_base_url = customBaseUrl.value.trim() + } else { + delete newExtra.custom_base_url_enabled + delete newExtra.custom_base_url + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/charts/TokenUsageTrend.vue b/frontend/src/components/charts/TokenUsageTrend.vue index a255fb03..4cd126b9 100644 --- a/frontend/src/components/charts/TokenUsageTrend.vue +++ b/frontend/src/components/charts/TokenUsageTrend.vue @@ -64,7 +64,8 @@ const chartColors = computed(() => ({ input: '#3b82f6', output: '#10b981', cacheCreation: '#f59e0b', - cacheRead: '#06b6d4' + cacheRead: '#06b6d4', + cacheHitRate: '#8b5cf6' })) const chartData = computed(() => { @@ -104,6 +105,19 @@ const chartData = computed(() => { backgroundColor: `${chartColors.value.cacheRead}20`, fill: true, tension: 0.3 + }, + { + label: 'Cache Hit Rate', + data: props.trendData.map((d) => { + const total = d.cache_read_tokens + d.cache_creation_tokens + return total > 0 ? (d.cache_read_tokens / total) * 100 : 0 + }), + borderColor: chartColors.value.cacheHitRate, + backgroundColor: `${chartColors.value.cacheHitRate}20`, + borderDash: [5, 5], + fill: false, + tension: 0.3, + yAxisID: 'yPercent' } ] } @@ -132,6 +146,9 @@ const lineOptions = computed(() => ({ tooltip: { callbacks: { label: (context: any) => { + if (context.dataset.yAxisID === 'yPercent') { + return `${context.dataset.label}: ${context.raw.toFixed(1)}%` + } return `${context.dataset.label}: ${formatTokens(context.raw)}` }, footer: (tooltipItems: any) => { @@ -168,6 +185,21 @@ const lineOptions = computed(() => ({ }, callback: (value: string | number) => formatTokens(Number(value)) } + }, + yPercent: { + position: 'right' as const, + min: 0, + max: 100, + grid: { + drawOnChartArea: false + }, + ticks: { + color: chartColors.value.cacheHitRate, + font: { + size: 10 + }, + callback: (value: string | number) => `${value}%` + } } } })) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 07a0e634..f5267d6a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -2318,6 +2318,11 @@ export default { target: 'Target TTL', targetHint: 'Select the TTL tier for billing' }, + customBaseUrl: { + label: 'Custom Relay URL', + hint: 'Forward requests to a custom relay service. Proxy URL will be passed as a query parameter.', + urlHint: 'Relay service URL (e.g., https://relay.example.com)', + }, clientAffinity: { label: 'Client Affinity Scheduling', hint: 'When enabled, new sessions prefer accounts previously used by this client to reduce account switching' @@ -4378,6 +4383,7 @@ export default { provider: 'Type', active: 'Active', endpoint: 'Endpoint', + bucket: 'Bucket', storagePath: 'Storage Path', capacityUsage: 'Capacity / Used', capacityUnlimited: 'Unlimited', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index a6b6e8b5..9581206e 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -2462,6 +2462,11 @@ export default { target: '目标 TTL', targetHint: '选择计费使用的 TTL 类型' }, + customBaseUrl: { + label: '自定义转发地址', + hint: '启用后将请求转发到自定义中继服务,代理地址将作为 URL 参数传递给中继服务', + urlHint: '中继服务地址(如 https://relay.example.com)', + }, clientAffinity: { label: '客户端亲和调度', hint: '启用后,新会话会优先调度到该客户端之前使用过的账号,避免频繁切换账号' @@ -4542,6 +4547,7 @@ export default { provider: '存储类型', active: '生效状态', endpoint: '端点', + bucket: '存储桶', storagePath: '存储路径', capacityUsage: '容量 / 已用', capacityUnlimited: '无限制', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 8ab48216..f9425ad0 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -734,6 +734,10 @@ export interface Account { cache_ttl_override_enabled?: boolean | null cache_ttl_override_target?: string | null + // 自定义 Base URL 中继转发(仅 Anthropic OAuth/SetupToken 账号有效) + custom_base_url_enabled?: boolean | null + custom_base_url?: string | null + // 客户端亲和调度(仅 Anthropic/Antigravity 平台有效) // 启用后新会话会优先调度到客户端之前使用过的账号 client_affinity_enabled?: boolean | null diff --git a/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue b/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue index d2aeb3ca..bfc9397d 100644 --- a/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue +++ b/frontend/src/views/admin/ops/components/OpsSystemLogTable.vue @@ -2,6 +2,7 @@ import { computed, onMounted, reactive, ref, watch } from 'vue' import { opsAPI, type OpsRuntimeLogConfig, type OpsSystemLog, type OpsSystemLogSinkHealth } from '@/api/admin/ops' import Pagination from '@/components/common/Pagination.vue' +import Select from '@/components/common/Select.vue' import { useAppStore } from '@/stores' const appStore = useAppStore() @@ -56,6 +57,37 @@ const filters = reactive({ q: '' }) +const runtimeLevelOptions = [ + { value: 'debug', label: 'debug' }, + { value: 'info', label: 'info' }, + { value: 'warn', label: 'warn' }, + { value: 'error', label: 'error' } +] + +const stacktraceLevelOptions = [ + { value: 'none', label: 'none' }, + { value: 'error', label: 'error' }, + { value: 'fatal', label: 'fatal' } +] + +const timeRangeOptions = [ + { value: '5m', label: '5m' }, + { value: '30m', label: '30m' }, + { value: '1h', label: '1h' }, + { value: '6h', label: '6h' }, + { value: '24h', label: '24h' }, + { value: '7d', label: '7d' }, + { value: '30d', label: '30d' } +] + +const filterLevelOptions = [ + { value: '', label: '全部' }, + { value: 'debug', label: 'debug' }, + { value: 'info', label: 'info' }, + { value: 'warn', label: 'warn' }, + { value: 'error', label: 'error' } +] + const levelBadgeClass = (level: string) => { const v = String(level || '').toLowerCase() if (v === 'error' || v === 'fatal') return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-300' @@ -347,20 +379,11 @@ onMounted(async () => {