diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 2629199f..20e33317 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -1,248 +1,248 @@ -// Code generated by Wire. DO NOT EDIT. - -//go:generate go run -mod=mod github.com/google/wire/cmd/wire -//go:build !wireinject -// +build !wireinject - -package main - -import ( - "context" - "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/handler" - "github.com/Wei-Shaw/sub2api/internal/handler/admin" - "github.com/Wei-Shaw/sub2api/internal/repository" - "github.com/Wei-Shaw/sub2api/internal/server" - "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/service" - "github.com/redis/go-redis/v9" - "log" - "net/http" - "time" -) - -import ( - _ "embed" - _ "github.com/Wei-Shaw/sub2api/ent/runtime" -) - -// Injectors from wire.go: - -func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { - configConfig, err := config.ProvideConfig() - if err != nil { - return nil, err - } - client, err := repository.ProvideEnt(configConfig) - if err != nil { - return nil, err - } - db, err := repository.ProvideSQLDB(client) - if err != nil { - return nil, err - } - userRepository := repository.NewUserRepository(client, db) - settingRepository := repository.NewSettingRepository(client) - settingService := service.NewSettingService(settingRepository, configConfig) - redisClient := repository.ProvideRedis(configConfig) - emailCache := repository.NewEmailCache(redisClient) - emailService := service.NewEmailService(settingRepository, emailCache) - turnstileVerifier := repository.NewTurnstileVerifier() - turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) - emailQueueService := service.ProvideEmailQueueService(emailService) - authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) - userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService) - userHandler := handler.NewUserHandler(userService) - apiKeyRepository := repository.NewApiKeyRepository(client) - groupRepository := repository.NewGroupRepository(client, db) - userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) - apiKeyCache := repository.NewApiKeyCache(redisClient) - apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) - apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) - usageLogRepository := repository.NewUsageLogRepository(client, db) - usageService := service.NewUsageService(usageLogRepository, userRepository) - usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - redeemCodeRepository := repository.NewRedeemCodeRepository(client) - billingCache := repository.NewBillingCache(redisClient) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) - redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) - redeemHandler := handler.NewRedeemHandler(redeemService) - subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) - dashboardService := service.NewDashboardService(usageLogRepository) - dashboardHandler := admin.NewDashboardHandler(dashboardService) - accountRepository := repository.NewAccountRepository(client, db) - proxyRepository := repository.NewProxyRepository(client, db) - proxyExitInfoProber := repository.NewProxyExitInfoProber() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) - adminUserHandler := admin.NewUserHandler(adminService) - groupHandler := admin.NewGroupHandler(adminService) - claudeOAuthClient := repository.NewClaudeOAuthClient() - oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) - openAIOAuthClient := repository.NewOpenAIOAuthClient() - openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) - geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) - geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() - geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) - geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) - rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService) - claudeUsageFetcher := repository.NewClaudeUsageFetcher() - antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) - usageCache := service.NewUsageCache() - accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) - geminiTokenCache := repository.NewGeminiTokenCache(redisClient) - geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) - gatewayCache := repository.NewGatewayCache(redisClient) - antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) - antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) - httpUpstream := repository.NewHTTPUpstream(configConfig) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) - accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) - concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) - crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) - accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) - oAuthHandler := admin.NewOAuthHandler(oAuthService) - openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) - geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) - antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) - proxyHandler := admin.NewProxyHandler(adminService) - adminRedeemHandler := admin.NewRedeemHandler(adminService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService) - updateCache := repository.NewUpdateCache(redisClient) - gitHubReleaseClient := repository.NewGitHubReleaseClient() - serviceBuildInfo := provideServiceBuildInfo(buildInfo) - updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) - systemHandler := handler.ProvideSystemHandler(updateService) - adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) - adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) - userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) - userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) - userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) - userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - pricingRemoteClient := repository.NewPricingRemoteClient() - pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) - if err != nil { - return nil, err - } - billingService := service.NewBillingService(configConfig, pricingService) - identityCache := repository.NewIdentityCache(redisClient) - identityService := service.NewIdentityService(identityCache) - timingWheelService := service.ProvideTimingWheelService() - deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) - geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) - handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) - jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) - adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) - apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) - httpServer := server.ProvideHTTPServer(configConfig, engine) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) - v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) - application := &Application{ - Server: httpServer, - Cleanup: v, - } - return application, nil -} - -// wire.go: - -type Application struct { - Server *http.Server - Cleanup func() -} - -func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { - return service.BuildInfo{ - Version: buildInfo.Version, - BuildType: buildInfo.BuildType, - } -} - -func provideCleanup( - entClient *ent.Client, - rdb *redis.Client, - tokenRefresh *service.TokenRefreshService, - pricing *service.PricingService, - emailQueue *service.EmailQueueService, - billingCache *service.BillingCacheService, - oauth *service.OAuthService, - openaiOAuth *service.OpenAIOAuthService, - geminiOAuth *service.GeminiOAuthService, - antigravityOAuth *service.AntigravityOAuthService, -) func() { - return func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - cleanupSteps := []struct { - name string - fn func() error - }{ - {"TokenRefreshService", func() error { - tokenRefresh.Stop() - return nil - }}, - {"PricingService", func() error { - pricing.Stop() - return nil - }}, - {"EmailQueueService", func() error { - emailQueue.Stop() - return nil - }}, - {"BillingCacheService", func() error { - billingCache.Stop() - return nil - }}, - {"OAuthService", func() error { - oauth.Stop() - return nil - }}, - {"OpenAIOAuthService", func() error { - openaiOAuth.Stop() - return nil - }}, - {"GeminiOAuthService", func() error { - geminiOAuth.Stop() - return nil - }}, - {"AntigravityOAuthService", func() error { - antigravityOAuth.Stop() - return nil - }}, - {"Redis", func() error { - return rdb.Close() - }}, - {"Ent", func() error { - return entClient.Close() - }}, - } - - for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) - - } else { - log.Printf("[Cleanup] %s succeeded", step.name) - } - } - - select { - case <-ctx.Done(): - log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") - default: - log.Printf("[Cleanup] All cleanup steps completed") - } - } -} +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package main + +import ( + "context" + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/handler/admin" + "github.com/Wei-Shaw/sub2api/internal/repository" + "github.com/Wei-Shaw/sub2api/internal/server" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" + "log" + "net/http" + "time" +) + +import ( + _ "embed" + _ "github.com/Wei-Shaw/sub2api/ent/runtime" +) + +// Injectors from wire.go: + +func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { + configConfig, err := config.ProvideConfig() + if err != nil { + return nil, err + } + client, err := repository.ProvideEnt(configConfig) + if err != nil { + return nil, err + } + db, err := repository.ProvideSQLDB(client) + if err != nil { + return nil, err + } + userRepository := repository.NewUserRepository(client, db) + settingRepository := repository.NewSettingRepository(client) + settingService := service.NewSettingService(settingRepository, configConfig) + redisClient := repository.ProvideRedis(configConfig) + emailCache := repository.NewEmailCache(redisClient) + emailService := service.NewEmailService(settingRepository, emailCache) + turnstileVerifier := repository.NewTurnstileVerifier() + turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) + emailQueueService := service.ProvideEmailQueueService(emailService) + authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) + userService := service.NewUserService(userRepository) + authHandler := handler.NewAuthHandler(configConfig, authService, userService) + userHandler := handler.NewUserHandler(userService) + apiKeyRepository := repository.NewApiKeyRepository(client) + groupRepository := repository.NewGroupRepository(client, db) + userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) + apiKeyCache := repository.NewApiKeyCache(redisClient) + apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) + usageLogRepository := repository.NewUsageLogRepository(client, db) + usageService := service.NewUsageService(usageLogRepository, userRepository) + usageHandler := handler.NewUsageHandler(usageService, apiKeyService) + redeemCodeRepository := repository.NewRedeemCodeRepository(client) + billingCache := repository.NewBillingCache(redisClient) + billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + redeemCache := repository.NewRedeemCache(redisClient) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) + redeemHandler := handler.NewRedeemHandler(redeemService) + subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) + dashboardService := service.NewDashboardService(usageLogRepository) + dashboardHandler := admin.NewDashboardHandler(dashboardService) + accountRepository := repository.NewAccountRepository(client, db) + proxyRepository := repository.NewProxyRepository(client, db) + proxyExitInfoProber := repository.NewProxyExitInfoProber() + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) + adminUserHandler := admin.NewUserHandler(adminService) + groupHandler := admin.NewGroupHandler(adminService) + claudeOAuthClient := repository.NewClaudeOAuthClient() + oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) + openAIOAuthClient := repository.NewOpenAIOAuthClient() + openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) + geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) + geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) + rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService) + claudeUsageFetcher := repository.NewClaudeUsageFetcher() + antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) + usageCache := service.NewUsageCache() + accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) + geminiTokenCache := repository.NewGeminiTokenCache(redisClient) + geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) + gatewayCache := repository.NewGatewayCache(redisClient) + antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) + antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) + httpUpstream := repository.NewHTTPUpstream(configConfig) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream) + concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) + concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) + crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) + oAuthHandler := admin.NewOAuthHandler(oAuthService) + openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) + geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) + antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) + proxyHandler := admin.NewProxyHandler(adminService) + adminRedeemHandler := admin.NewRedeemHandler(adminService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService) + updateCache := repository.NewUpdateCache(redisClient) + gitHubReleaseClient := repository.NewGitHubReleaseClient() + serviceBuildInfo := provideServiceBuildInfo(buildInfo) + updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) + systemHandler := handler.ProvideSystemHandler(updateService) + adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) + adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) + userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) + userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) + userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) + userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) + pricingRemoteClient := repository.NewPricingRemoteClient() + pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) + if err != nil { + return nil, err + } + billingService := service.NewBillingService(configConfig, pricingService) + identityCache := repository.NewIdentityCache(redisClient) + identityService := service.NewIdentityService(identityCache) + timingWheelService := service.ProvideTimingWheelService() + deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) + geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) + handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) + jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) + adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) + apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) + engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) + httpServer := server.ProvideHTTPServer(configConfig, engine) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) + v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + application := &Application{ + Server: httpServer, + Cleanup: v, + } + return application, nil +} + +// wire.go: + +type Application struct { + Server *http.Server + Cleanup func() +} + +func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { + return service.BuildInfo{ + Version: buildInfo.Version, + BuildType: buildInfo.BuildType, + } +} + +func provideCleanup( + entClient *ent.Client, + rdb *redis.Client, + tokenRefresh *service.TokenRefreshService, + pricing *service.PricingService, + emailQueue *service.EmailQueueService, + billingCache *service.BillingCacheService, + oauth *service.OAuthService, + openaiOAuth *service.OpenAIOAuthService, + geminiOAuth *service.GeminiOAuthService, + antigravityOAuth *service.AntigravityOAuthService, +) func() { + return func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cleanupSteps := []struct { + name string + fn func() error + }{ + {"TokenRefreshService", func() error { + tokenRefresh.Stop() + return nil + }}, + {"PricingService", func() error { + pricing.Stop() + return nil + }}, + {"EmailQueueService", func() error { + emailQueue.Stop() + return nil + }}, + {"BillingCacheService", func() error { + billingCache.Stop() + return nil + }}, + {"OAuthService", func() error { + oauth.Stop() + return nil + }}, + {"OpenAIOAuthService", func() error { + openaiOAuth.Stop() + return nil + }}, + {"GeminiOAuthService", func() error { + geminiOAuth.Stop() + return nil + }}, + {"AntigravityOAuthService", func() error { + antigravityOAuth.Stop() + return nil + }}, + {"Redis", func() error { + return rdb.Close() + }}, + {"Ent", func() error { + return entClient.Close() + }}, + } + + for _, step := range cleanupSteps { + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + + } else { + log.Printf("[Cleanup] %s succeeded", step.name) + } + } + + select { + case <-ctx.Done(): + log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") + default: + log.Printf("[Cleanup] All cleanup steps completed") + } + } +} diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index e207839d..99557f9a 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -1,323 +1,323 @@ -package admin - -import ( - "strconv" - "strings" - - "github.com/Wei-Shaw/sub2api/internal/handler/dto" - "github.com/Wei-Shaw/sub2api/internal/pkg/response" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/gin-gonic/gin" -) - -// ProxyHandler handles admin proxy management -type ProxyHandler struct { - adminService service.AdminService -} - -// NewProxyHandler creates a new admin proxy handler -func NewProxyHandler(adminService service.AdminService) *ProxyHandler { - return &ProxyHandler{ - adminService: adminService, - } -} - -// CreateProxyRequest represents create proxy request -type CreateProxyRequest struct { - Name string `json:"name" binding:"required"` - Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` - Host string `json:"host" binding:"required"` - Port int `json:"port" binding:"required,min=1,max=65535"` - Username string `json:"username"` - Password string `json:"password"` -} - -// UpdateProxyRequest represents update proxy request -type UpdateProxyRequest struct { - Name string `json:"name"` - Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"` - Host string `json:"host"` - Port int `json:"port" binding:"omitempty,min=1,max=65535"` - Username string `json:"username"` - Password string `json:"password"` - Status string `json:"status" binding:"omitempty,oneof=active inactive"` -} - -// List handles listing all proxies with pagination -// GET /api/v1/admin/proxies -func (h *ProxyHandler) List(c *gin.Context) { - page, pageSize := response.ParsePagination(c) - protocol := c.Query("protocol") - status := c.Query("status") - search := c.Query("search") - - proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) - if err != nil { - response.ErrorFrom(c, err) - return - } - - out := make([]dto.Proxy, 0, len(proxies)) - for i := range proxies { - out = append(out, *dto.ProxyFromService(&proxies[i])) - } - response.Paginated(c, out, total, page, pageSize) -} - -// GetAll handles getting all active proxies without pagination -// GET /api/v1/admin/proxies/all -// Optional query param: with_count=true to include account count per proxy -func (h *ProxyHandler) GetAll(c *gin.Context) { - withCount := c.Query("with_count") == "true" - - if withCount { - proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) - for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) - } - response.Success(c, out) - return - } - - proxies, err := h.adminService.GetAllProxies(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - - out := make([]dto.Proxy, 0, len(proxies)) - for i := range proxies { - out = append(out, *dto.ProxyFromService(&proxies[i])) - } - response.Success(c, out) -} - -// GetByID handles getting a proxy by ID -// GET /api/v1/admin/proxies/:id -func (h *ProxyHandler) GetByID(c *gin.Context) { - proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid proxy ID") - return - } - - proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) -} - -// Create handles creating a new proxy -// POST /api/v1/admin/proxies -func (h *ProxyHandler) Create(c *gin.Context) { - var req CreateProxyRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ - Name: strings.TrimSpace(req.Name), - Protocol: strings.TrimSpace(req.Protocol), - Host: strings.TrimSpace(req.Host), - Port: req.Port, - Username: strings.TrimSpace(req.Username), - Password: strings.TrimSpace(req.Password), - }) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) -} - -// Update handles updating a proxy -// PUT /api/v1/admin/proxies/:id -func (h *ProxyHandler) Update(c *gin.Context) { - proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid proxy ID") - return - } - - var req UpdateProxyRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{ - Name: strings.TrimSpace(req.Name), - Protocol: strings.TrimSpace(req.Protocol), - Host: strings.TrimSpace(req.Host), - Port: req.Port, - Username: strings.TrimSpace(req.Username), - Password: strings.TrimSpace(req.Password), - Status: strings.TrimSpace(req.Status), - }) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) -} - -// Delete handles deleting a proxy -// DELETE /api/v1/admin/proxies/:id -func (h *ProxyHandler) Delete(c *gin.Context) { - proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid proxy ID") - return - } - - err = h.adminService.DeleteProxy(c.Request.Context(), proxyID) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, gin.H{"message": "Proxy deleted successfully"}) -} - -// Test handles testing proxy connectivity -// POST /api/v1/admin/proxies/:id/test -func (h *ProxyHandler) Test(c *gin.Context) { - proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid proxy ID") - return - } - - result, err := h.adminService.TestProxy(c.Request.Context(), proxyID) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, result) -} - -// GetStats handles getting proxy statistics -// GET /api/v1/admin/proxies/:id/stats -func (h *ProxyHandler) GetStats(c *gin.Context) { - proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid proxy ID") - return - } - - // Return mock data for now - _ = proxyID - response.Success(c, gin.H{ - "total_accounts": 0, - "active_accounts": 0, - "total_requests": 0, - "success_rate": 100.0, - "average_latency": 0, - }) -} - -// GetProxyAccounts handles getting accounts using a proxy -// GET /api/v1/admin/proxies/:id/accounts -func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { - proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - response.BadRequest(c, "Invalid proxy ID") - return - } - - page, pageSize := response.ParsePagination(c) - - accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize) - if err != nil { - response.ErrorFrom(c, err) - return - } - - out := make([]dto.Account, 0, len(accounts)) - for i := range accounts { - out = append(out, *dto.AccountFromService(&accounts[i])) - } - response.Paginated(c, out, total, page, pageSize) -} - -// BatchCreateProxyItem represents a single proxy in batch create request -type BatchCreateProxyItem struct { - Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` - Host string `json:"host" binding:"required"` - Port int `json:"port" binding:"required,min=1,max=65535"` - Username string `json:"username"` - Password string `json:"password"` -} - -// BatchCreateRequest represents batch create proxies request -type BatchCreateRequest struct { - Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"` -} - -// BatchCreate handles batch creating proxies -// POST /api/v1/admin/proxies/batch -func (h *ProxyHandler) BatchCreate(c *gin.Context) { - var req BatchCreateRequest - if err := c.ShouldBindJSON(&req); err != nil { - response.BadRequest(c, "Invalid request: "+err.Error()) - return - } - - created := 0 - skipped := 0 - - for _, item := range req.Proxies { - // Trim all string fields - host := strings.TrimSpace(item.Host) - protocol := strings.TrimSpace(item.Protocol) - username := strings.TrimSpace(item.Username) - password := strings.TrimSpace(item.Password) - - // Check for duplicates (same host, port, username, password) - exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password) - if err != nil { - response.ErrorFrom(c, err) - return - } - - if exists { - skipped++ - continue - } - - // Create proxy with default name - _, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ - Name: "default", - Protocol: protocol, - Host: host, - Port: item.Port, - Username: username, - Password: password, - }) - if err != nil { - // If creation fails due to duplicate, count as skipped - skipped++ - continue - } - - created++ - } - - response.Success(c, gin.H{ - "created": created, - "skipped": skipped, - }) -} +package admin + +import ( + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// ProxyHandler handles admin proxy management +type ProxyHandler struct { + adminService service.AdminService +} + +// NewProxyHandler creates a new admin proxy handler +func NewProxyHandler(adminService service.AdminService) *ProxyHandler { + return &ProxyHandler{ + adminService: adminService, + } +} + +// CreateProxyRequest represents create proxy request +type CreateProxyRequest struct { + Name string `json:"name" binding:"required"` + Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"` + Host string `json:"host" binding:"required"` + Port int `json:"port" binding:"required,min=1,max=65535"` + Username string `json:"username"` + Password string `json:"password"` +} + +// UpdateProxyRequest represents update proxy request +type UpdateProxyRequest struct { + Name string `json:"name"` + Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"` + Host string `json:"host"` + Port int `json:"port" binding:"omitempty,min=1,max=65535"` + Username string `json:"username"` + Password string `json:"password"` + Status string `json:"status" binding:"omitempty,oneof=active inactive"` +} + +// List handles listing all proxies with pagination +// GET /api/v1/admin/proxies +func (h *ProxyHandler) List(c *gin.Context) { + page, pageSize := response.ParsePagination(c) + protocol := c.Query("protocol") + status := c.Query("status") + search := c.Query("search") + + proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]dto.Proxy, 0, len(proxies)) + for i := range proxies { + out = append(out, *dto.ProxyFromService(&proxies[i])) + } + response.Paginated(c, out, total, page, pageSize) +} + +// GetAll handles getting all active proxies without pagination +// GET /api/v1/admin/proxies/all +// Optional query param: with_count=true to include account count per proxy +func (h *ProxyHandler) GetAll(c *gin.Context) { + withCount := c.Query("with_count") == "true" + + if withCount { + proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) + for i := range proxies { + out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) + } + response.Success(c, out) + return + } + + proxies, err := h.adminService.GetAllProxies(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]dto.Proxy, 0, len(proxies)) + for i := range proxies { + out = append(out, *dto.ProxyFromService(&proxies[i])) + } + response.Success(c, out) +} + +// GetByID handles getting a proxy by ID +// GET /api/v1/admin/proxies/:id +func (h *ProxyHandler) GetByID(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.ProxyFromService(proxy)) +} + +// Create handles creating a new proxy +// POST /api/v1/admin/proxies +func (h *ProxyHandler) Create(c *gin.Context) { + var req CreateProxyRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + Name: strings.TrimSpace(req.Name), + Protocol: strings.TrimSpace(req.Protocol), + Host: strings.TrimSpace(req.Host), + Port: req.Port, + Username: strings.TrimSpace(req.Username), + Password: strings.TrimSpace(req.Password), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.ProxyFromService(proxy)) +} + +// Update handles updating a proxy +// PUT /api/v1/admin/proxies/:id +func (h *ProxyHandler) Update(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + var req UpdateProxyRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{ + Name: strings.TrimSpace(req.Name), + Protocol: strings.TrimSpace(req.Protocol), + Host: strings.TrimSpace(req.Host), + Port: req.Port, + Username: strings.TrimSpace(req.Username), + Password: strings.TrimSpace(req.Password), + Status: strings.TrimSpace(req.Status), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.ProxyFromService(proxy)) +} + +// Delete handles deleting a proxy +// DELETE /api/v1/admin/proxies/:id +func (h *ProxyHandler) Delete(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + err = h.adminService.DeleteProxy(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Proxy deleted successfully"}) +} + +// Test handles testing proxy connectivity +// POST /api/v1/admin/proxies/:id/test +func (h *ProxyHandler) Test(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + result, err := h.adminService.TestProxy(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + +// GetStats handles getting proxy statistics +// GET /api/v1/admin/proxies/:id/stats +func (h *ProxyHandler) GetStats(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + // Return mock data for now + _ = proxyID + response.Success(c, gin.H{ + "total_accounts": 0, + "active_accounts": 0, + "total_requests": 0, + "success_rate": 100.0, + "average_latency": 0, + }) +} + +// GetProxyAccounts handles getting accounts using a proxy +// GET /api/v1/admin/proxies/:id/accounts +func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + page, pageSize := response.ParsePagination(c) + + accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]dto.Account, 0, len(accounts)) + for i := range accounts { + out = append(out, *dto.AccountFromService(&accounts[i])) + } + response.Paginated(c, out, total, page, pageSize) +} + +// BatchCreateProxyItem represents a single proxy in batch create request +type BatchCreateProxyItem struct { + Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"` + Host string `json:"host" binding:"required"` + Port int `json:"port" binding:"required,min=1,max=65535"` + Username string `json:"username"` + Password string `json:"password"` +} + +// BatchCreateRequest represents batch create proxies request +type BatchCreateRequest struct { + Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"` +} + +// BatchCreate handles batch creating proxies +// POST /api/v1/admin/proxies/batch +func (h *ProxyHandler) BatchCreate(c *gin.Context) { + var req BatchCreateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + created := 0 + skipped := 0 + + for _, item := range req.Proxies { + // Trim all string fields + host := strings.TrimSpace(item.Host) + protocol := strings.TrimSpace(item.Protocol) + username := strings.TrimSpace(item.Username) + password := strings.TrimSpace(item.Password) + + // Check for duplicates (same host, port, username, password) + exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if exists { + skipped++ + continue + } + + // Create proxy with default name + _, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + Name: "default", + Protocol: protocol, + Host: host, + Port: item.Port, + Username: username, + Password: password, + }) + if err != nil { + // If creation fails due to duplicate, count as skipped + skipped++ + continue + } + + created++ + } + + response.Success(c, gin.H{ + "created": created, + "skipped": skipped, + }) +} diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index a16c921a..8a81c09a 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -1,157 +1,138 @@ -// Package httpclient 提供共享 HTTP 客户端池 -// -// 性能优化说明: -// 原实现在多个服务中重复创建 http.Client: -// 1. proxy_probe_service.go: 每次探测创建新客户端 -// 2. pricing_service.go: 每次请求创建新客户端 -// 3. turnstile_service.go: 每次验证创建新客户端 -// 4. github_release_service.go: 每次请求创建新客户端 -// 5. claude_usage_service.go: 每次请求创建新客户端 -// -// 新实现使用统一的客户端池: -// 1. 相同配置复用同一 http.Client 实例 -// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 -// 3. 支持 HTTP/HTTPS/SOCKS5 代理 -// 4. 支持严格代理模式(代理失败则返回错误) -package httpclient - -import ( - "context" - "crypto/tls" - "fmt" - "net" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "golang.org/x/net/proxy" -) - -// Transport 连接池默认配置 -const ( - defaultMaxIdleConns = 100 // 最大空闲连接数 - defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 - defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 -) - -// Options 定义共享 HTTP 客户端的构建参数 -type Options struct { - ProxyURL string // 代理 URL(支持 http/https/socks5) - Timeout time.Duration // 请求总超时时间 - ResponseHeaderTimeout time.Duration // 等待响应头超时时间 - InsecureSkipVerify bool // 是否跳过 TLS 证书验证 - ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 - - // 可选的连接池参数(不设置则使用默认值) - MaxIdleConns int // 最大空闲连接总数(默认 100) - MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10) - MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制) -} - -// sharedClients 存储按配置参数缓存的 http.Client 实例 -var sharedClients sync.Map - -// GetClient 返回共享的 HTTP 客户端实例 -// 性能优化:相同配置复用同一客户端,避免重复创建 Transport -func GetClient(opts Options) (*http.Client, error) { - key := buildClientKey(opts) - if cached, ok := sharedClients.Load(key); ok { - if client, ok := cached.(*http.Client); ok { - return client, nil - } - } - - client, err := buildClient(opts) - if err != nil { - if opts.ProxyStrict { - return nil, err - } - fallback := opts - fallback.ProxyURL = "" - client, _ = buildClient(fallback) - } - - actual, _ := sharedClients.LoadOrStore(key, client) - if c, ok := actual.(*http.Client); ok { - return c, nil - } - return client, nil -} - -func buildClient(opts Options) (*http.Client, error) { - transport, err := buildTransport(opts) - if err != nil { - return nil, err - } - - return &http.Client{ - Transport: transport, - Timeout: opts.Timeout, - }, nil -} - -func buildTransport(opts Options) (*http.Transport, error) { - // 使用自定义值或默认值 - maxIdleConns := opts.MaxIdleConns - if maxIdleConns <= 0 { - maxIdleConns = defaultMaxIdleConns - } - maxIdleConnsPerHost := opts.MaxIdleConnsPerHost - if maxIdleConnsPerHost <= 0 { - maxIdleConnsPerHost = defaultMaxIdleConnsPerHost - } - - transport := &http.Transport{ - MaxIdleConns: maxIdleConns, - MaxIdleConnsPerHost: maxIdleConnsPerHost, - MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制 - IdleConnTimeout: defaultIdleConnTimeout, - ResponseHeaderTimeout: opts.ResponseHeaderTimeout, - } - - if opts.InsecureSkipVerify { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - - proxyURL := strings.TrimSpace(opts.ProxyURL) - if proxyURL == "" { - return transport, nil - } - - parsed, err := url.Parse(proxyURL) - if err != nil { - return nil, err - } - - switch strings.ToLower(parsed.Scheme) { - case "http", "https": - transport.Proxy = http.ProxyURL(parsed) - case "socks5", "socks5h": - dialer, err := proxy.FromURL(parsed, proxy.Direct) - if err != nil { - return nil, err - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - default: - return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme) - } - - return transport, nil -} - -func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d", - strings.TrimSpace(opts.ProxyURL), - opts.Timeout.String(), - opts.ResponseHeaderTimeout.String(), - opts.InsecureSkipVerify, - opts.ProxyStrict, - opts.MaxIdleConns, - opts.MaxIdleConnsPerHost, - opts.MaxConnsPerHost, - ) -} +// Package httpclient 提供共享 HTTP 客户端池 +// +// 性能优化说明: +// 原实现在多个服务中重复创建 http.Client: +// 1. proxy_probe_service.go: 每次探测创建新客户端 +// 2. pricing_service.go: 每次请求创建新客户端 +// 3. turnstile_service.go: 每次验证创建新客户端 +// 4. github_release_service.go: 每次请求创建新客户端 +// 5. claude_usage_service.go: 每次请求创建新客户端 +// +// 新实现使用统一的客户端池: +// 1. 相同配置复用同一 http.Client 实例 +// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 +// 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理 +// 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险) +package httpclient + +import ( + "crypto/tls" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" +) + +// Transport 连接池默认配置 +const ( + defaultMaxIdleConns = 100 // 最大空闲连接数 + defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 + defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间 +) + +// Options 定义共享 HTTP 客户端的构建参数 +type Options struct { + ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h) + Timeout time.Duration // 请求总超时时间 + ResponseHeaderTimeout time.Duration // 等待响应头超时时间 + InsecureSkipVerify bool // 是否跳过 TLS 证书验证 + + // 可选的连接池参数(不设置则使用默认值) + MaxIdleConns int // 最大空闲连接总数(默认 100) + MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10) + MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制) +} + +// sharedClients 存储按配置参数缓存的 http.Client 实例 +var sharedClients sync.Map + +// GetClient 返回共享的 HTTP 客户端实例 +// 性能优化:相同配置复用同一客户端,避免重复创建 Transport +// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 +func GetClient(opts Options) (*http.Client, error) { + key := buildClientKey(opts) + if cached, ok := sharedClients.Load(key); ok { + if client, ok := cached.(*http.Client); ok { + return client, nil + } + } + + client, err := buildClient(opts) + if err != nil { + return nil, err + } + + actual, _ := sharedClients.LoadOrStore(key, client) + if c, ok := actual.(*http.Client); ok { + return c, nil + } + return client, nil +} + +func buildClient(opts Options) (*http.Client, error) { + transport, err := buildTransport(opts) + if err != nil { + return nil, err + } + + return &http.Client{ + Transport: transport, + Timeout: opts.Timeout, + }, nil +} + +func buildTransport(opts Options) (*http.Transport, error) { + // 使用自定义值或默认值 + maxIdleConns := opts.MaxIdleConns + if maxIdleConns <= 0 { + maxIdleConns = defaultMaxIdleConns + } + maxIdleConnsPerHost := opts.MaxIdleConnsPerHost + if maxIdleConnsPerHost <= 0 { + maxIdleConnsPerHost = defaultMaxIdleConnsPerHost + } + + transport := &http.Transport{ + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制 + IdleConnTimeout: defaultIdleConnTimeout, + ResponseHeaderTimeout: opts.ResponseHeaderTimeout, + } + + if opts.InsecureSkipVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + proxyURL := strings.TrimSpace(opts.ProxyURL) + if proxyURL == "" { + return transport, nil + } + + parsed, err := url.Parse(proxyURL) + if err != nil { + return nil, err + } + + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, err + } + + return transport, nil +} + +func buildClientKey(opts Options) string { + return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d", + strings.TrimSpace(opts.ProxyURL), + opts.Timeout.String(), + opts.ResponseHeaderTimeout.String(), + opts.InsecureSkipVerify, + opts.MaxIdleConns, + opts.MaxIdleConnsPerHost, + opts.MaxConnsPerHost, + ) +} diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go new file mode 100644 index 00000000..91b224a2 --- /dev/null +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -0,0 +1,62 @@ +// Package proxyutil 提供统一的代理配置功能 +// +// 支持的代理协议: +// - HTTP/HTTPS: 通过 Transport.Proxy 设置 +// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +package proxyutil + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +// ConfigureTransportProxy 根据代理 URL 配置 Transport +// +// 支持的协议: +// - http/https: 设置 transport.Proxy +// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// +// 参数: +// - transport: 需要配置的 http.Transport +// - proxyURL: 代理地址,nil 表示直连 +// +// 返回: +// - error: 代理配置错误(协议不支持或 dialer 创建失败) +func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error { + if proxyURL == nil { + return nil + } + + scheme := strings.ToLower(proxyURL.Scheme) + switch scheme { + case "http", "https": + transport.Proxy = http.ProxyURL(proxyURL) + return nil + + case "socks5", "socks5h": + dialer, err := proxy.FromURL(proxyURL, proxy.Direct) + if err != nil { + return fmt.Errorf("create socks5 dialer: %w", err) + } + // 优先使用支持 context 的 DialContext,以支持请求取消和超时 + if contextDialer, ok := dialer.(proxy.ContextDialer); ok { + transport.DialContext = contextDialer.DialContext + } else { + // 回退路径:如果 dialer 不支持 ContextDialer,则包装为简单的 DialContext + // 注意:此回退不支持请求取消和超时控制 + transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + } + return nil + + default: + return fmt.Errorf("unsupported proxy scheme: %s", scheme) + } +} diff --git a/backend/internal/pkg/proxyutil/dialer_test.go b/backend/internal/pkg/proxyutil/dialer_test.go new file mode 100644 index 00000000..f153cc9f --- /dev/null +++ b/backend/internal/pkg/proxyutil/dialer_test.go @@ -0,0 +1,204 @@ +package proxyutil + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigureTransportProxy_Nil(t *testing.T) { + transport := &http.Transport{} + err := ConfigureTransportProxy(transport, nil) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy") + assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext") +} + +func TestConfigureTransportProxy_HTTP(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("http://proxy.example.com:8080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy") + assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext") +} + +func TestConfigureTransportProxy_HTTPS(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy") + assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext") +} + +func TestConfigureTransportProxy_SOCKS5(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy") + assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext") +} + +func TestConfigureTransportProxy_SOCKS5H(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5h://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy") + assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext") +} + +func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) { + testCases := []struct { + scheme string + useProxy bool // true = uses Transport.Proxy, false = uses DialContext + }{ + {"HTTP://proxy.example.com:8080", true}, + {"Http://proxy.example.com:8080", true}, + {"HTTPS://proxy.example.com:8443", true}, + {"Https://proxy.example.com:8443", true}, + {"SOCKS5://socks.example.com:1080", false}, + {"Socks5://socks.example.com:1080", false}, + {"SOCKS5H://socks.example.com:1080", false}, + {"Socks5h://socks.example.com:1080", false}, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse(tc.scheme) + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + if tc.useProxy { + assert.NotNil(t, transport.Proxy) + assert.Nil(t, transport.DialContext) + } else { + assert.Nil(t, transport.Proxy) + assert.NotNil(t, transport.DialContext) + } + }) + } +} + +func TestConfigureTransportProxy_Unsupported(t *testing.T) { + testCases := []string{ + "ftp://ftp.example.com", + "file:///path/to/file", + "unknown://example.com", + } + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse(tc) + + err := ConfigureTransportProxy(transport, proxyURL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported proxy scheme") + }) + } +} + +func TestConfigureTransportProxy_WithAuth(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext") +} + +func TestConfigureTransportProxy_EmptyScheme(t *testing.T) { + transport := &http.Transport{} + // 空 scheme 的 URL + proxyURL := &url.URL{Host: "proxy.example.com:8080"} + + err := ConfigureTransportProxy(transport, proxyURL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported proxy scheme") +} + +func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) { + // 验证代理配置不会覆盖 Transport 的其他配置 + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + } + proxyURL, _ := url.Parse("socks5://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved") + assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved") + assert.NotNil(t, transport.DialContext, "DialContext should be set") +} + +func TestConfigureTransportProxy_IPv6(t *testing.T) { + testCases := []struct { + name string + proxyURL string + }{ + {"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"}, + {"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"}, + {"HTTP with IPv6", "http://[::1]:8080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, err := url.Parse(tc.proxyURL) + require.NoError(t, err, "URL should be parseable") + + err = ConfigureTransportProxy(transport, proxyURL) + require.NoError(t, err) + }) + } +} + +func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) { + testCases := []struct { + name string + proxyURL string + }{ + // 密码包含 @ 符号(URL 编码为 %40) + {"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"}, + // 密码包含 : 符号(URL 编码为 %3A) + {"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"}, + // 密码包含 / 符号(URL 编码为 %2F) + {"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"}, + // 复杂密码 + {"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, err := url.Parse(tc.proxyURL) + require.NoError(t, err, "URL should be parseable") + + err = ConfigureTransportProxy(transport, proxyURL) + require.NoError(t, err) + assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext") + }) + } +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 55620f78..35e7f535 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -1,251 +1,257 @@ -package repository - -import ( - "context" - "encoding/json" - "fmt" - "log" - "net/http" - "net/url" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" - "github.com/Wei-Shaw/sub2api/internal/service" - - "github.com/imroc/req/v3" -) - -func NewClaudeOAuthClient() service.ClaudeOAuthClient { - return &claudeOAuthService{ - baseURL: "https://claude.ai", - tokenURL: oauth.TokenURL, - clientFactory: createReqClient, - } -} - -type claudeOAuthService struct { - baseURL string - tokenURL string - clientFactory func(proxyURL string) *req.Client -} - -func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) - - var orgs []struct { - UUID string `json:"uuid"` - } - - targetURL := s.baseURL + "/api/organizations" - log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) - - resp, err := client.R(). - SetContext(ctx). - SetCookies(&http.Cookie{ - Name: "sessionKey", - Value: sessionKey, - }). - SetSuccessResult(&orgs). - Get(targetURL) - - if err != nil { - log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) - return "", fmt.Errorf("request failed: %w", err) - } - - log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) - - if !resp.IsSuccessState() { - return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) - } - - if len(orgs) == 0 { - return "", fmt.Errorf("no organizations found") - } - - log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID) - return orgs[0].UUID, nil -} - -func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.clientFactory(proxyURL) - - authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) - - reqBody := map[string]any{ - "response_type": "code", - "client_id": oauth.ClientID, - "organization_uuid": orgUUID, - "redirect_uri": oauth.RedirectURI, - "scope": scope, - "state": state, - "code_challenge": codeChallenge, - "code_challenge_method": "S256", - } - - reqBodyJSON, _ := json.Marshal(reqBody) - log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) - log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) - - var result struct { - RedirectURI string `json:"redirect_uri"` - } - - resp, err := client.R(). - SetContext(ctx). - SetCookies(&http.Cookie{ - Name: "sessionKey", - Value: sessionKey, - }). - SetHeader("Accept", "application/json"). - SetHeader("Accept-Language", "en-US,en;q=0.9"). - SetHeader("Cache-Control", "no-cache"). - SetHeader("Origin", "https://claude.ai"). - SetHeader("Referer", "https://claude.ai/new"). - SetHeader("Content-Type", "application/json"). - SetBody(reqBody). - SetSuccessResult(&result). - Post(authURL) - - if err != nil { - log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) - return "", fmt.Errorf("request failed: %w", err) - } - - log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) - - if !resp.IsSuccessState() { - return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) - } - - if result.RedirectURI == "" { - return "", fmt.Errorf("no redirect_uri in response") - } - - parsedURL, err := url.Parse(result.RedirectURI) - if err != nil { - return "", fmt.Errorf("failed to parse redirect_uri: %w", err) - } - - queryParams := parsedURL.Query() - authCode := queryParams.Get("code") - responseState := queryParams.Get("state") - - if authCode == "" { - return "", fmt.Errorf("no authorization code in redirect_uri") - } - - fullCode := authCode - if responseState != "" { - fullCode = authCode + "#" + responseState - } - - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) - return fullCode, nil -} - -func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) - - // Parse code which may contain state in format "authCode#state" - authCode := code - codeState := "" - if idx := strings.Index(code, "#"); idx != -1 { - authCode = code[:idx] - codeState = code[idx+1:] - } - - reqBody := map[string]any{ - "code": authCode, - "grant_type": "authorization_code", - "client_id": oauth.ClientID, - "redirect_uri": oauth.RedirectURI, - "code_verifier": codeVerifier, - } - - if codeState != "" { - reqBody["state"] = codeState - } - - // Setup token requires longer expiration (1 year) - if isSetupToken { - reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds - } - - reqBodyJSON, _ := json.Marshal(reqBody) - log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) - log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) - - var tokenResp oauth.TokenResponse - - resp, err := client.R(). - SetContext(ctx). - SetHeader("Content-Type", "application/json"). - SetBody(reqBody). - SetSuccessResult(&tokenResp). - Post(s.tokenURL) - - if err != nil { - log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) - return nil, fmt.Errorf("request failed: %w", err) - } - - log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) - - if !resp.IsSuccessState() { - return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) - } - - log.Printf("[OAuth] Step 3 SUCCESS - Got access token") - return &tokenResp, nil -} - -func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client := s.clientFactory(proxyURL) - - // 使用 JSON 格式(与 ExchangeCodeForToken 保持一致) - // Anthropic OAuth API 期望 JSON 格式的请求体 - reqBody := map[string]any{ - "grant_type": "refresh_token", - "refresh_token": refreshToken, - "client_id": oauth.ClientID, - } - - var tokenResp oauth.TokenResponse - - resp, err := client.R(). - SetContext(ctx). - SetHeader("Content-Type", "application/json"). - SetBody(reqBody). - SetSuccessResult(&tokenResp). - Post(s.tokenURL) - - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - - if !resp.IsSuccessState() { - return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String()) - } - - return &tokenResp, nil -} - -func createReqClient(proxyURL string) *req.Client { - return getSharedReqClient(reqClientOptions{ - ProxyURL: proxyURL, - Timeout: 60 * time.Second, - Impersonate: true, - }) -} - -func prefix(s string, n int) string { - if n <= 0 { - return "" - } - if len(s) <= n { - return s - } - return s[:n] -} +package repository + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/imroc/req/v3" +) + +func NewClaudeOAuthClient() service.ClaudeOAuthClient { + return &claudeOAuthService{ + baseURL: "https://claude.ai", + tokenURL: oauth.TokenURL, + clientFactory: createReqClient, + } +} + +type claudeOAuthService struct { + baseURL string + tokenURL string + clientFactory func(proxyURL string) *req.Client +} + +func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { + client := s.clientFactory(proxyURL) + + var orgs []struct { + UUID string `json:"uuid"` + } + + targetURL := s.baseURL + "/api/organizations" + log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) + + resp, err := client.R(). + SetContext(ctx). + SetCookies(&http.Cookie{ + Name: "sessionKey", + Value: sessionKey, + }). + SetSuccessResult(&orgs). + Get(targetURL) + + if err != nil { + log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) + return "", fmt.Errorf("request failed: %w", err) + } + + log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + + if !resp.IsSuccessState() { + return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) + } + + if len(orgs) == 0 { + return "", fmt.Errorf("no organizations found") + } + + log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID) + return orgs[0].UUID, nil +} + +func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { + client := s.clientFactory(proxyURL) + + authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) + + reqBody := map[string]any{ + "response_type": "code", + "client_id": oauth.ClientID, + "organization_uuid": orgUUID, + "redirect_uri": oauth.RedirectURI, + "scope": scope, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + } + + reqBodyJSON, _ := json.Marshal(reqBody) + log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) + log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) + + var result struct { + RedirectURI string `json:"redirect_uri"` + } + + resp, err := client.R(). + SetContext(ctx). + SetCookies(&http.Cookie{ + Name: "sessionKey", + Value: sessionKey, + }). + SetHeader("Accept", "application/json"). + SetHeader("Accept-Language", "en-US,en;q=0.9"). + SetHeader("Cache-Control", "no-cache"). + SetHeader("Origin", "https://claude.ai"). + SetHeader("Referer", "https://claude.ai/new"). + SetHeader("Content-Type", "application/json"). + SetBody(reqBody). + SetSuccessResult(&result). + Post(authURL) + + if err != nil { + log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) + return "", fmt.Errorf("request failed: %w", err) + } + + log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + + if !resp.IsSuccessState() { + return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) + } + + if result.RedirectURI == "" { + return "", fmt.Errorf("no redirect_uri in response") + } + + parsedURL, err := url.Parse(result.RedirectURI) + if err != nil { + return "", fmt.Errorf("failed to parse redirect_uri: %w", err) + } + + queryParams := parsedURL.Query() + authCode := queryParams.Get("code") + responseState := queryParams.Get("state") + + if authCode == "" { + return "", fmt.Errorf("no authorization code in redirect_uri") + } + + fullCode := authCode + if responseState != "" { + fullCode = authCode + "#" + responseState + } + + log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) + return fullCode, nil +} + +func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + client := s.clientFactory(proxyURL) + + // Parse code which may contain state in format "authCode#state" + authCode := code + codeState := "" + if idx := strings.Index(code, "#"); idx != -1 { + authCode = code[:idx] + codeState = code[idx+1:] + } + + reqBody := map[string]any{ + "code": authCode, + "grant_type": "authorization_code", + "client_id": oauth.ClientID, + "redirect_uri": oauth.RedirectURI, + "code_verifier": codeVerifier, + } + + if codeState != "" { + reqBody["state"] = codeState + } + + // Setup token requires longer expiration (1 year) + if isSetupToken { + reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds + } + + reqBodyJSON, _ := json.Marshal(reqBody) + log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) + log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) + + var tokenResp oauth.TokenResponse + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Content-Type", "application/json"). + SetBody(reqBody). + SetSuccessResult(&tokenResp). + Post(s.tokenURL) + + if err != nil { + log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) + return nil, fmt.Errorf("request failed: %w", err) + } + + log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) + } + + log.Printf("[OAuth] Step 3 SUCCESS - Got access token") + return &tokenResp, nil +} + +func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + client := s.clientFactory(proxyURL) + + // 使用 JSON 格式(与 ExchangeCodeForToken 保持一致) + // Anthropic OAuth API 期望 JSON 格式的请求体 + reqBody := map[string]any{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": oauth.ClientID, + } + + var tokenResp oauth.TokenResponse + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Content-Type", "application/json"). + SetBody(reqBody). + SetSuccessResult(&tokenResp). + Post(s.tokenURL) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String()) + } + + return &tokenResp, nil +} + +func createReqClient(proxyURL string) *req.Client { + // 禁用 CookieJar,确保每次授权都是干净的会话 + client := req.C(). + SetTimeout(60 * time.Second). + ImpersonateChrome(). + SetCookieJar(nil) // 禁用 CookieJar + + if strings.TrimSpace(proxyURL) != "" { + client.SetProxyURL(strings.TrimSpace(proxyURL)) + } + + return client +} + +func prefix(s string, n int) string { + if n <= 0 { + return "" + } + if len(s) <= n { + return s + } + return s[:n] +} diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index fd6ae1ba..f0669979 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -1,604 +1,611 @@ -package repository - -import ( - "errors" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/service" -) - -// 默认配置常量 -// 这些值在配置文件未指定时作为回退默认值使用 -const ( - // directProxyKey: 无代理时的缓存键标识 - directProxyKey = "direct" - // defaultMaxIdleConns: 默认最大空闲连接总数 - // HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发 - defaultMaxIdleConns = 240 - // defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数 - defaultMaxIdleConnsPerHost = 120 - // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) - // 达到上限后新请求会等待,而非无限创建连接 - defaultMaxConnsPerHost = 240 - // defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟) - // 超时后连接会被关闭,释放系统资源 - defaultIdleConnTimeout = 300 * time.Second - // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) - // LLM 请求可能排队较久,需要较长超时 - defaultResponseHeaderTimeout = 300 * time.Second - // defaultMaxUpstreamClients: 默认最大客户端缓存数量 - // 超出后会淘汰最久未使用的客户端 - defaultMaxUpstreamClients = 5000 - // defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟) - defaultClientIdleTTLSeconds = 900 -) - -var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached") - -// poolSettings 连接池配置参数 -// 封装 Transport 所需的各项连接池参数 -type poolSettings struct { - maxIdleConns int // 最大空闲连接总数 - maxIdleConnsPerHost int // 每主机最大空闲连接数 - maxConnsPerHost int // 每主机最大连接数(含活跃) - idleConnTimeout time.Duration // 空闲连接超时时间 - responseHeaderTimeout time.Duration // 等待响应头超时时间 -} - -// upstreamClientEntry 上游客户端缓存条目 -// 记录客户端实例及其元数据,用于连接池管理和淘汰策略 -type upstreamClientEntry struct { - client *http.Client // HTTP 客户端实例 - proxyKey string // 代理标识(用于检测代理变更) - poolKey string // 连接池配置标识(用于检测配置变更) - lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 - inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 -} - -// httpUpstreamService 通用 HTTP 上游服务 -// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理 -// -// 架构设计: -// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例 -// - 每个客户端拥有独立的 Transport 连接池 -// - 支持 LRU + 空闲时间双重淘汰策略 -// -// 性能优化: -// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client -// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销 -// 3. 支持账号级隔离与空闲回收,降低连接层关联风险 -// 4. 达到最大连接数后等待可用连接,而非无限创建 -// 5. 仅回收空闲客户端,避免中断活跃请求 -// 6. HTTP/2 多路复用,连接上限不等于并发请求上限 -// 7. 代理变更时清空旧连接池,避免复用错误代理 -// 8. 账号并发数与连接池上限对应(账号隔离策略下) -type httpUpstreamService struct { - cfg *config.Config // 全局配置 - mu sync.RWMutex // 保护 clients map 的读写锁 - clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定 -} - -// NewHTTPUpstream 创建通用 HTTP 上游服务 -// 使用配置中的连接池参数构建 Transport -// -// 参数: -// - cfg: 全局配置,包含连接池参数和隔离策略 -// -// 返回: -// - service.HTTPUpstream 接口实现 -func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { - return &httpUpstreamService{ - cfg: cfg, - clients: make(map[string]*upstreamClientEntry), - } -} - -// Do 执行 HTTP 请求 -// 根据隔离策略获取或创建客户端,并跟踪请求生命周期 -// -// 参数: -// - req: HTTP 请求对象 -// - proxyURL: 代理地址,空字符串表示直连 -// - accountID: 账户 ID,用于账户级隔离 -// - accountConcurrency: 账户并发限制,用于动态调整连接池大小 -// -// 返回: -// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数) -// - error: 请求错误 -// -// 注意: -// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏 -// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 -func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { - // 获取或创建对应的客户端,并标记请求占用 - entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) - if err != nil { - return nil, err - } - - // 执行请求 - resp, err := entry.client.Do(req) - if err != nil { - // 请求失败,立即减少计数 - atomic.AddInt64(&entry.inFlight, -1) - atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) - return nil, err - } - - // 包装响应体,在关闭时自动减少计数并更新时间戳 - // 这确保了流式响应(如 SSE)在完全读取前不会被淘汰 - resp.Body = wrapTrackedBody(resp.Body, func() { - atomic.AddInt64(&entry.inFlight, -1) - atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) - }) - - return resp, nil -} - -// acquireClient 获取或创建客户端,并标记为进行中请求 -// 用于请求路径,避免在获取后被淘汰 -func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true) -} - -// getOrCreateClient 获取或创建客户端 -// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更 -// -// 参数: -// - proxyURL: 代理地址 -// - accountID: 账户 ID -// - accountConcurrency: 账户并发限制 -// -// 返回: -// - *upstreamClientEntry: 客户端缓存条目 -// -// 隔离策略说明: -// - proxy: 按代理地址隔离,同一代理共享客户端 -// - account: 按账户隔离,同一账户共享客户端(代理变更时重建) -// - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { - entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) - return entry -} - -// getClientEntry 获取或创建客户端条目 -// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 -// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误 -func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { - // 获取隔离模式 - isolation := s.getIsolationMode() - // 标准化代理 URL 并解析 - proxyKey, parsedProxy := normalizeProxyURL(proxyURL) - // 构建缓存键(根据隔离策略不同) - cacheKey := buildCacheKey(isolation, proxyKey, accountID) - // 构建连接池配置键(用于检测配置变更) - poolKey := s.buildPoolKey(isolation, accountConcurrency) - - now := time.Now() - nowUnix := now.UnixNano() - - // 读锁快速路径:命中缓存直接返回,减少锁竞争 - s.mu.RLock() - if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { - atomic.StoreInt64(&entry.lastUsed, nowUnix) - if markInFlight { - atomic.AddInt64(&entry.inFlight, 1) - } - s.mu.RUnlock() - return entry, nil - } - s.mu.RUnlock() - - // 写锁慢路径:创建或重建客户端 - s.mu.Lock() - if entry, ok := s.clients[cacheKey]; ok { - if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { - atomic.StoreInt64(&entry.lastUsed, nowUnix) - if markInFlight { - atomic.AddInt64(&entry.inFlight, 1) - } - s.mu.Unlock() - return entry, nil - } - s.removeClientLocked(cacheKey, entry) - } - - // 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建 - if enforceLimit && s.maxUpstreamClients() > 0 { - s.evictIdleLocked(now) - if len(s.clients) >= s.maxUpstreamClients() { - if !s.evictOldestIdleLocked() { - s.mu.Unlock() - return nil, errUpstreamClientLimitReached - } - } - } - - // 缓存未命中或需要重建,创建新客户端 - settings := s.resolvePoolSettings(isolation, accountConcurrency) - client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)} - entry := &upstreamClientEntry{ - client: client, - proxyKey: proxyKey, - poolKey: poolKey, - } - atomic.StoreInt64(&entry.lastUsed, nowUnix) - if markInFlight { - atomic.StoreInt64(&entry.inFlight, 1) - } - s.clients[cacheKey] = entry - - // 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的 - s.evictIdleLocked(now) - s.evictOverLimitLocked() - s.mu.Unlock() - return entry, nil -} - -// shouldReuseEntry 判断缓存条目是否可复用 -// 若代理或连接池配置发生变化,则需要重建客户端 -func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool { - if entry == nil { - return false - } - if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey { - return false - } - if entry.poolKey != poolKey { - return false - } - return true -} - -// removeClientLocked 移除客户端(需持有锁) -// 从缓存中删除并关闭空闲连接 -// -// 参数: -// - key: 缓存键 -// - entry: 客户端条目 -func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) { - delete(s.clients, key) - if entry != nil && entry.client != nil { - // 关闭空闲连接,释放系统资源 - // 注意:这不会中断活跃连接 - entry.client.CloseIdleConnections() - } -} - -// evictIdleLocked 淘汰空闲超时的客户端(需持有锁) -// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目 -// -// 参数: -// - now: 当前时间 -func (s *httpUpstreamService) evictIdleLocked(now time.Time) { - ttl := s.clientIdleTTL() - if ttl <= 0 { - return - } - // 计算淘汰截止时间 - cutoff := now.Add(-ttl).UnixNano() - for key, entry := range s.clients { - // 跳过有活跃请求的客户端 - if atomic.LoadInt64(&entry.inFlight) != 0 { - continue - } - // 淘汰超时的空闲客户端 - if atomic.LoadInt64(&entry.lastUsed) <= cutoff { - s.removeClientLocked(key, entry) - } - } -} - -// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁) -func (s *httpUpstreamService) evictOldestIdleLocked() bool { - var ( - oldestKey string - oldestEntry *upstreamClientEntry - oldestTime int64 - ) - // 查找最久未使用且无活跃请求的客户端 - for key, entry := range s.clients { - // 跳过有活跃请求的客户端 - if atomic.LoadInt64(&entry.inFlight) != 0 { - continue - } - lastUsed := atomic.LoadInt64(&entry.lastUsed) - if oldestEntry == nil || lastUsed < oldestTime { - oldestKey = key - oldestEntry = entry - oldestTime = lastUsed - } - } - // 所有客户端都有活跃请求,无法淘汰 - if oldestEntry == nil { - return false - } - s.removeClientLocked(oldestKey, oldestEntry) - return true -} - -// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁) -// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端 -func (s *httpUpstreamService) evictOverLimitLocked() bool { - maxClients := s.maxUpstreamClients() - if maxClients <= 0 { - return false - } - evicted := false - // 循环淘汰直到满足数量限制 - for len(s.clients) > maxClients { - if !s.evictOldestIdleLocked() { - return evicted - } - evicted = true - } - return evicted -} - -// getIsolationMode 获取连接池隔离模式 -// 从配置中读取,无效值回退到 account_proxy 模式 -// -// 返回: -// - string: 隔离模式(proxy/account/account_proxy) -func (s *httpUpstreamService) getIsolationMode() string { - if s.cfg == nil { - return config.ConnectionPoolIsolationAccountProxy - } - mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation)) - if mode == "" { - return config.ConnectionPoolIsolationAccountProxy - } - switch mode { - case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy: - return mode - default: - return config.ConnectionPoolIsolationAccountProxy - } -} - -// maxUpstreamClients 获取最大客户端缓存数量 -// 从配置中读取,无效值使用默认值 -func (s *httpUpstreamService) maxUpstreamClients() int { - if s.cfg == nil { - return defaultMaxUpstreamClients - } - if s.cfg.Gateway.MaxUpstreamClients > 0 { - return s.cfg.Gateway.MaxUpstreamClients - } - return defaultMaxUpstreamClients -} - -// clientIdleTTL 获取客户端空闲回收阈值 -// 从配置中读取,无效值使用默认值 -func (s *httpUpstreamService) clientIdleTTL() time.Duration { - if s.cfg == nil { - return time.Duration(defaultClientIdleTTLSeconds) * time.Second - } - if s.cfg.Gateway.ClientIdleTTLSeconds > 0 { - return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second - } - return time.Duration(defaultClientIdleTTLSeconds) * time.Second -} - -// resolvePoolSettings 解析连接池配置 -// 根据隔离策略和账户并发数动态调整连接池参数 -// -// 参数: -// - isolation: 隔离模式 -// - accountConcurrency: 账户并发限制 -// -// 返回: -// - poolSettings: 连接池配置 -// -// 说明: -// - 账户隔离模式下,连接池大小与账户并发数对应 -// - 这确保了单账户不会占用过多连接资源 -func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings { - settings := defaultPoolSettings(s.cfg) - // 账户隔离模式下,根据账户并发数调整连接池大小 - if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 { - settings.maxIdleConns = accountConcurrency - settings.maxIdleConnsPerHost = accountConcurrency - settings.maxConnsPerHost = accountConcurrency - } - return settings -} - -// buildPoolKey 构建连接池配置键 -// 用于检测配置变更,配置变更时需要重建客户端 -// -// 参数: -// - isolation: 隔离模式 -// - accountConcurrency: 账户并发限制 -// -// 返回: -// - string: 配置键 -func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string { - if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy { - if accountConcurrency > 0 { - return fmt.Sprintf("account:%d", accountConcurrency) - } - } - return "default" -} - -// buildCacheKey 构建客户端缓存键 -// 根据隔离策略决定缓存键的组成 -// -// 参数: -// - isolation: 隔离模式 -// - proxyKey: 代理标识 -// - accountID: 账户 ID -// -// 返回: -// - string: 缓存键 -// -// 缓存键格式: -// - proxy 模式: "proxy:{proxyKey}" -// - account 模式: "account:{accountID}" -// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}" -func buildCacheKey(isolation, proxyKey string, accountID int64) string { - switch isolation { - case config.ConnectionPoolIsolationAccount: - return fmt.Sprintf("account:%d", accountID) - case config.ConnectionPoolIsolationAccountProxy: - return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) - default: - return fmt.Sprintf("proxy:%s", proxyKey) - } -} - -// normalizeProxyURL 标准化代理 URL -// 处理空值和解析错误,返回标准化的键和解析后的 URL -// -// 参数: -// - raw: 原始代理 URL 字符串 -// -// 返回: -// - string: 标准化的代理键(空或解析失败返回 "direct") -// - *url.URL: 解析后的 URL(空或解析失败返回 nil) -func normalizeProxyURL(raw string) (string, *url.URL) { - proxyURL := strings.TrimSpace(raw) - if proxyURL == "" { - return directProxyKey, nil - } - parsed, err := url.Parse(proxyURL) - if err != nil { - return directProxyKey, nil - } - parsed.Scheme = strings.ToLower(parsed.Scheme) - parsed.Host = strings.ToLower(parsed.Host) - parsed.Path = "" - parsed.RawPath = "" - parsed.RawQuery = "" - parsed.Fragment = "" - parsed.ForceQuery = false - if hostname := parsed.Hostname(); hostname != "" { - port := parsed.Port() - if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") { - port = "" - } - hostname = strings.ToLower(hostname) - if port != "" { - parsed.Host = net.JoinHostPort(hostname, port) - } else { - parsed.Host = hostname - } - } - return parsed.String(), parsed -} - -// defaultPoolSettings 获取默认连接池配置 -// 从全局配置中读取,无效值使用常量默认值 -// -// 参数: -// - cfg: 全局配置 -// -// 返回: -// - poolSettings: 连接池配置 -func defaultPoolSettings(cfg *config.Config) poolSettings { - maxIdleConns := defaultMaxIdleConns - maxIdleConnsPerHost := defaultMaxIdleConnsPerHost - maxConnsPerHost := defaultMaxConnsPerHost - idleConnTimeout := defaultIdleConnTimeout - responseHeaderTimeout := defaultResponseHeaderTimeout - - if cfg != nil { - if cfg.Gateway.MaxIdleConns > 0 { - maxIdleConns = cfg.Gateway.MaxIdleConns - } - if cfg.Gateway.MaxIdleConnsPerHost > 0 { - maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost - } - if cfg.Gateway.MaxConnsPerHost >= 0 { - maxConnsPerHost = cfg.Gateway.MaxConnsPerHost - } - if cfg.Gateway.IdleConnTimeoutSeconds > 0 { - idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second - } - if cfg.Gateway.ResponseHeaderTimeout > 0 { - responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second - } - } - - return poolSettings{ - maxIdleConns: maxIdleConns, - maxIdleConnsPerHost: maxIdleConnsPerHost, - maxConnsPerHost: maxConnsPerHost, - idleConnTimeout: idleConnTimeout, - responseHeaderTimeout: responseHeaderTimeout, - } -} - -// buildUpstreamTransport 构建上游请求的 Transport -// 使用配置文件中的连接池参数,支持生产环境调优 -// -// 参数: -// - settings: 连接池配置 -// - proxyURL: 代理 URL(nil 表示直连) -// -// 返回: -// - *http.Transport: 配置好的 Transport 实例 -// -// Transport 参数说明: -// - MaxIdleConns: 所有主机的最大空闲连接总数 -// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率) -// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) -// - IdleConnTimeout: 空闲连接超时(超时后关闭) -// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) -func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport { - transport := &http.Transport{ - MaxIdleConns: settings.maxIdleConns, - MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, - MaxConnsPerHost: settings.maxConnsPerHost, - IdleConnTimeout: settings.idleConnTimeout, - ResponseHeaderTimeout: settings.responseHeaderTimeout, - } - if proxyURL != nil { - transport.Proxy = http.ProxyURL(proxyURL) - } - return transport -} - -// trackedBody 带跟踪功能的响应体包装器 -// 在 Close 时执行回调,用于更新请求计数 -type trackedBody struct { - io.ReadCloser // 原始响应体 - once sync.Once - onClose func() // 关闭时的回调函数 -} - -// Close 关闭响应体并执行回调 -// 使用 sync.Once 确保回调只执行一次 -func (b *trackedBody) Close() error { - err := b.ReadCloser.Close() - if b.onClose != nil { - b.once.Do(b.onClose) - } - return err -} - -// wrapTrackedBody 包装响应体以跟踪关闭事件 -// 用于在响应体关闭时更新 inFlight 计数 -// -// 参数: -// - body: 原始响应体 -// - onClose: 关闭时的回调函数 -// -// 返回: -// - io.ReadCloser: 包装后的响应体 -func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser { - if body == nil { - return body - } - return &trackedBody{ReadCloser: body, onClose: onClose} -} +package repository + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// 默认配置常量 +// 这些值在配置文件未指定时作为回退默认值使用 +const ( + // directProxyKey: 无代理时的缓存键标识 + directProxyKey = "direct" + // defaultMaxIdleConns: 默认最大空闲连接总数 + // HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发 + defaultMaxIdleConns = 240 + // defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数 + defaultMaxIdleConnsPerHost = 120 + // defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接) + // 达到上限后新请求会等待,而非无限创建连接 + defaultMaxConnsPerHost = 240 + // defaultIdleConnTimeout: 默认空闲连接超时时间(5分钟) + // 超时后连接会被关闭,释放系统资源 + defaultIdleConnTimeout = 300 * time.Second + // defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟) + // LLM 请求可能排队较久,需要较长超时 + defaultResponseHeaderTimeout = 300 * time.Second + // defaultMaxUpstreamClients: 默认最大客户端缓存数量 + // 超出后会淘汰最久未使用的客户端 + defaultMaxUpstreamClients = 5000 + // defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟) + defaultClientIdleTTLSeconds = 900 +) + +var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached") + +// poolSettings 连接池配置参数 +// 封装 Transport 所需的各项连接池参数 +type poolSettings struct { + maxIdleConns int // 最大空闲连接总数 + maxIdleConnsPerHost int // 每主机最大空闲连接数 + maxConnsPerHost int // 每主机最大连接数(含活跃) + idleConnTimeout time.Duration // 空闲连接超时时间 + responseHeaderTimeout time.Duration // 等待响应头超时时间 +} + +// upstreamClientEntry 上游客户端缓存条目 +// 记录客户端实例及其元数据,用于连接池管理和淘汰策略 +type upstreamClientEntry struct { + client *http.Client // HTTP 客户端实例 + proxyKey string // 代理标识(用于检测代理变更) + poolKey string // 连接池配置标识(用于检测配置变更) + lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 + inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 +} + +// httpUpstreamService 通用 HTTP 上游服务 +// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理 +// +// 架构设计: +// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例 +// - 每个客户端拥有独立的 Transport 连接池 +// - 支持 LRU + 空闲时间双重淘汰策略 +// +// 性能优化: +// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client +// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销 +// 3. 支持账号级隔离与空闲回收,降低连接层关联风险 +// 4. 达到最大连接数后等待可用连接,而非无限创建 +// 5. 仅回收空闲客户端,避免中断活跃请求 +// 6. HTTP/2 多路复用,连接上限不等于并发请求上限 +// 7. 代理变更时清空旧连接池,避免复用错误代理 +// 8. 账号并发数与连接池上限对应(账号隔离策略下) +type httpUpstreamService struct { + cfg *config.Config // 全局配置 + mu sync.RWMutex // 保护 clients map 的读写锁 + clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定 +} + +// NewHTTPUpstream 创建通用 HTTP 上游服务 +// 使用配置中的连接池参数构建 Transport +// +// 参数: +// - cfg: 全局配置,包含连接池参数和隔离策略 +// +// 返回: +// - service.HTTPUpstream 接口实现 +func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream { + return &httpUpstreamService{ + cfg: cfg, + clients: make(map[string]*upstreamClientEntry), + } +} + +// Do 执行 HTTP 请求 +// 根据隔离策略获取或创建客户端,并跟踪请求生命周期 +// +// 参数: +// - req: HTTP 请求对象 +// - proxyURL: 代理地址,空字符串表示直连 +// - accountID: 账户 ID,用于账户级隔离 +// - accountConcurrency: 账户并发限制,用于动态调整连接池大小 +// +// 返回: +// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数) +// - error: 请求错误 +// +// 注意: +// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏 +// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断 +func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + // 获取或创建对应的客户端,并标记请求占用 + entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency) + if err != nil { + return nil, err + } + + // 执行请求 + resp, err := entry.client.Do(req) + if err != nil { + // 请求失败,立即减少计数 + atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + return nil, err + } + + // 包装响应体,在关闭时自动减少计数并更新时间戳 + // 这确保了流式响应(如 SSE)在完全读取前不会被淘汰 + resp.Body = wrapTrackedBody(resp.Body, func() { + atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + }) + + return resp, nil +} + +// acquireClient 获取或创建客户端,并标记为进行中请求 +// 用于请求路径,避免在获取后被淘汰 +func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true) +} + +// getOrCreateClient 获取或创建客户端 +// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更 +// +// 参数: +// - proxyURL: 代理地址 +// - accountID: 账户 ID +// - accountConcurrency: 账户并发限制 +// +// 返回: +// - *upstreamClientEntry: 客户端缓存条目 +// +// 隔离策略说明: +// - proxy: 按代理地址隔离,同一代理共享客户端 +// - account: 按账户隔离,同一账户共享客户端(代理变更时重建) +// - account_proxy: 按账户+代理组合隔离,最细粒度 +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { + entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) + return entry +} + +// getClientEntry 获取或创建客户端条目 +// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 +// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误 +func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { + // 获取隔离模式 + isolation := s.getIsolationMode() + // 标准化代理 URL 并解析 + proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + // 构建缓存键(根据隔离策略不同) + cacheKey := buildCacheKey(isolation, proxyKey, accountID) + // 构建连接池配置键(用于检测配置变更) + poolKey := s.buildPoolKey(isolation, accountConcurrency) + + now := time.Now() + nowUnix := now.UnixNano() + + // 读锁快速路径:命中缓存直接返回,减少锁竞争 + s.mu.RLock() + if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.RUnlock() + return entry, nil + } + s.mu.RUnlock() + + // 写锁慢路径:创建或重建客户端 + s.mu.Lock() + if entry, ok := s.clients[cacheKey]; ok { + if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.Unlock() + return entry, nil + } + s.removeClientLocked(cacheKey, entry) + } + + // 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建 + if enforceLimit && s.maxUpstreamClients() > 0 { + s.evictIdleLocked(now) + if len(s.clients) >= s.maxUpstreamClients() { + if !s.evictOldestIdleLocked() { + s.mu.Unlock() + return nil, errUpstreamClientLimitReached + } + } + } + + // 缓存未命中或需要重建,创建新客户端 + settings := s.resolvePoolSettings(isolation, accountConcurrency) + transport, err := buildUpstreamTransport(settings, parsedProxy) + if err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("build transport: %w", err) + } + client := &http.Client{Transport: transport} + entry := &upstreamClientEntry{ + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + } + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.StoreInt64(&entry.inFlight, 1) + } + s.clients[cacheKey] = entry + + // 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的 + s.evictIdleLocked(now) + s.evictOverLimitLocked() + s.mu.Unlock() + return entry, nil +} + +// shouldReuseEntry 判断缓存条目是否可复用 +// 若代理或连接池配置发生变化,则需要重建客户端 +func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool { + if entry == nil { + return false + } + if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey { + return false + } + if entry.poolKey != poolKey { + return false + } + return true +} + +// removeClientLocked 移除客户端(需持有锁) +// 从缓存中删除并关闭空闲连接 +// +// 参数: +// - key: 缓存键 +// - entry: 客户端条目 +func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) { + delete(s.clients, key) + if entry != nil && entry.client != nil { + // 关闭空闲连接,释放系统资源 + // 注意:这不会中断活跃连接 + entry.client.CloseIdleConnections() + } +} + +// evictIdleLocked 淘汰空闲超时的客户端(需持有锁) +// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目 +// +// 参数: +// - now: 当前时间 +func (s *httpUpstreamService) evictIdleLocked(now time.Time) { + ttl := s.clientIdleTTL() + if ttl <= 0 { + return + } + // 计算淘汰截止时间 + cutoff := now.Add(-ttl).UnixNano() + for key, entry := range s.clients { + // 跳过有活跃请求的客户端 + if atomic.LoadInt64(&entry.inFlight) != 0 { + continue + } + // 淘汰超时的空闲客户端 + if atomic.LoadInt64(&entry.lastUsed) <= cutoff { + s.removeClientLocked(key, entry) + } + } +} + +// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁) +func (s *httpUpstreamService) evictOldestIdleLocked() bool { + var ( + oldestKey string + oldestEntry *upstreamClientEntry + oldestTime int64 + ) + // 查找最久未使用且无活跃请求的客户端 + for key, entry := range s.clients { + // 跳过有活跃请求的客户端 + if atomic.LoadInt64(&entry.inFlight) != 0 { + continue + } + lastUsed := atomic.LoadInt64(&entry.lastUsed) + if oldestEntry == nil || lastUsed < oldestTime { + oldestKey = key + oldestEntry = entry + oldestTime = lastUsed + } + } + // 所有客户端都有活跃请求,无法淘汰 + if oldestEntry == nil { + return false + } + s.removeClientLocked(oldestKey, oldestEntry) + return true +} + +// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁) +// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端 +func (s *httpUpstreamService) evictOverLimitLocked() bool { + maxClients := s.maxUpstreamClients() + if maxClients <= 0 { + return false + } + evicted := false + // 循环淘汰直到满足数量限制 + for len(s.clients) > maxClients { + if !s.evictOldestIdleLocked() { + return evicted + } + evicted = true + } + return evicted +} + +// getIsolationMode 获取连接池隔离模式 +// 从配置中读取,无效值回退到 account_proxy 模式 +// +// 返回: +// - string: 隔离模式(proxy/account/account_proxy) +func (s *httpUpstreamService) getIsolationMode() string { + if s.cfg == nil { + return config.ConnectionPoolIsolationAccountProxy + } + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation)) + if mode == "" { + return config.ConnectionPoolIsolationAccountProxy + } + switch mode { + case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy: + return mode + default: + return config.ConnectionPoolIsolationAccountProxy + } +} + +// maxUpstreamClients 获取最大客户端缓存数量 +// 从配置中读取,无效值使用默认值 +func (s *httpUpstreamService) maxUpstreamClients() int { + if s.cfg == nil { + return defaultMaxUpstreamClients + } + if s.cfg.Gateway.MaxUpstreamClients > 0 { + return s.cfg.Gateway.MaxUpstreamClients + } + return defaultMaxUpstreamClients +} + +// clientIdleTTL 获取客户端空闲回收阈值 +// 从配置中读取,无效值使用默认值 +func (s *httpUpstreamService) clientIdleTTL() time.Duration { + if s.cfg == nil { + return time.Duration(defaultClientIdleTTLSeconds) * time.Second + } + if s.cfg.Gateway.ClientIdleTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second + } + return time.Duration(defaultClientIdleTTLSeconds) * time.Second +} + +// resolvePoolSettings 解析连接池配置 +// 根据隔离策略和账户并发数动态调整连接池参数 +// +// 参数: +// - isolation: 隔离模式 +// - accountConcurrency: 账户并发限制 +// +// 返回: +// - poolSettings: 连接池配置 +// +// 说明: +// - 账户隔离模式下,连接池大小与账户并发数对应 +// - 这确保了单账户不会占用过多连接资源 +func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings { + settings := defaultPoolSettings(s.cfg) + // 账户隔离模式下,根据账户并发数调整连接池大小 + if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 { + settings.maxIdleConns = accountConcurrency + settings.maxIdleConnsPerHost = accountConcurrency + settings.maxConnsPerHost = accountConcurrency + } + return settings +} + +// buildPoolKey 构建连接池配置键 +// 用于检测配置变更,配置变更时需要重建客户端 +// +// 参数: +// - isolation: 隔离模式 +// - accountConcurrency: 账户并发限制 +// +// 返回: +// - string: 配置键 +func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string { + if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy { + if accountConcurrency > 0 { + return fmt.Sprintf("account:%d", accountConcurrency) + } + } + return "default" +} + +// buildCacheKey 构建客户端缓存键 +// 根据隔离策略决定缓存键的组成 +// +// 参数: +// - isolation: 隔离模式 +// - proxyKey: 代理标识 +// - accountID: 账户 ID +// +// 返回: +// - string: 缓存键 +// +// 缓存键格式: +// - proxy 模式: "proxy:{proxyKey}" +// - account 模式: "account:{accountID}" +// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}" +func buildCacheKey(isolation, proxyKey string, accountID int64) string { + switch isolation { + case config.ConnectionPoolIsolationAccount: + return fmt.Sprintf("account:%d", accountID) + case config.ConnectionPoolIsolationAccountProxy: + return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) + default: + return fmt.Sprintf("proxy:%s", proxyKey) + } +} + +// normalizeProxyURL 标准化代理 URL +// 处理空值和解析错误,返回标准化的键和解析后的 URL +// +// 参数: +// - raw: 原始代理 URL 字符串 +// +// 返回: +// - string: 标准化的代理键(空或解析失败返回 "direct") +// - *url.URL: 解析后的 URL(空或解析失败返回 nil) +func normalizeProxyURL(raw string) (string, *url.URL) { + proxyURL := strings.TrimSpace(raw) + if proxyURL == "" { + return directProxyKey, nil + } + parsed, err := url.Parse(proxyURL) + if err != nil { + return directProxyKey, nil + } + parsed.Scheme = strings.ToLower(parsed.Scheme) + parsed.Host = strings.ToLower(parsed.Host) + parsed.Path = "" + parsed.RawPath = "" + parsed.RawQuery = "" + parsed.Fragment = "" + parsed.ForceQuery = false + if hostname := parsed.Hostname(); hostname != "" { + port := parsed.Port() + if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") { + port = "" + } + hostname = strings.ToLower(hostname) + if port != "" { + parsed.Host = net.JoinHostPort(hostname, port) + } else { + parsed.Host = hostname + } + } + return parsed.String(), parsed +} + +// defaultPoolSettings 获取默认连接池配置 +// 从全局配置中读取,无效值使用常量默认值 +// +// 参数: +// - cfg: 全局配置 +// +// 返回: +// - poolSettings: 连接池配置 +func defaultPoolSettings(cfg *config.Config) poolSettings { + maxIdleConns := defaultMaxIdleConns + maxIdleConnsPerHost := defaultMaxIdleConnsPerHost + maxConnsPerHost := defaultMaxConnsPerHost + idleConnTimeout := defaultIdleConnTimeout + responseHeaderTimeout := defaultResponseHeaderTimeout + + if cfg != nil { + if cfg.Gateway.MaxIdleConns > 0 { + maxIdleConns = cfg.Gateway.MaxIdleConns + } + if cfg.Gateway.MaxIdleConnsPerHost > 0 { + maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost + } + if cfg.Gateway.MaxConnsPerHost >= 0 { + maxConnsPerHost = cfg.Gateway.MaxConnsPerHost + } + if cfg.Gateway.IdleConnTimeoutSeconds > 0 { + idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second + } + if cfg.Gateway.ResponseHeaderTimeout > 0 { + responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second + } + } + + return poolSettings{ + maxIdleConns: maxIdleConns, + maxIdleConnsPerHost: maxIdleConnsPerHost, + maxConnsPerHost: maxConnsPerHost, + idleConnTimeout: idleConnTimeout, + responseHeaderTimeout: responseHeaderTimeout, + } +} + +// buildUpstreamTransport 构建上游请求的 Transport +// 使用配置文件中的连接池参数,支持生产环境调优 +// +// 参数: +// - settings: 连接池配置 +// - proxyURL: 代理 URL(nil 表示直连) +// +// 返回: +// - *http.Transport: 配置好的 Transport 实例 +// - error: 代理配置错误 +// +// Transport 参数说明: +// - MaxIdleConns: 所有主机的最大空闲连接总数 +// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率) +// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) +// - IdleConnTimeout: 空闲连接超时(超时后关闭) +// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) { + transport := &http.Transport{ + MaxIdleConns: settings.maxIdleConns, + MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, + MaxConnsPerHost: settings.maxConnsPerHost, + IdleConnTimeout: settings.idleConnTimeout, + ResponseHeaderTimeout: settings.responseHeaderTimeout, + } + if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { + return nil, err + } + return transport, nil +} + +// trackedBody 带跟踪功能的响应体包装器 +// 在 Close 时执行回调,用于更新请求计数 +type trackedBody struct { + io.ReadCloser // 原始响应体 + once sync.Once + onClose func() // 关闭时的回调函数 +} + +// Close 关闭响应体并执行回调 +// 使用 sync.Once 确保回调只执行一次 +func (b *trackedBody) Close() error { + err := b.ReadCloser.Close() + if b.onClose != nil { + b.once.Do(b.onClose) + } + return err +} + +// wrapTrackedBody 包装响应体以跟踪关闭事件 +// 用于在响应体关闭时更新 inFlight 计数 +// +// 参数: +// - body: 原始响应体 +// - onClose: 关闭时的回调函数 +// +// 返回: +// - io.ReadCloser: 包装后的响应体 +func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser { + if body == nil { + return body + } + return &trackedBody{ReadCloser: body, onClose: onClose} +} diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index c434a85c..1e7430a3 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -1,66 +1,70 @@ -package repository - -import ( - "net/http" - "net/url" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/config" -) - -// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作 -// 这是 Go 基准测试的常见模式,确保测试结果准确 -var httpClientSink *http.Client - -// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销 -// -// 测试目的: -// - 验证连接池复用相比每次新建的性能提升 -// - 量化内存分配差异 -// -// 预期结果: -// - "复用" 子测试应显著快于 "新建" -// - "复用" 子测试应零内存分配 -func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { - // 创建测试配置 - cfg := &config.Config{ - Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300}, - } - upstream := NewHTTPUpstream(cfg) - svc, ok := upstream.(*httpUpstreamService) - if !ok { - b.Fatalf("类型断言失败,无法获取 httpUpstreamService") - } - - proxyURL := "http://127.0.0.1:8080" - b.ReportAllocs() // 报告内存分配统计 - - // 子测试:每次新建客户端 - // 模拟未优化前的行为,每次请求都创建新的 http.Client - b.Run("新建", func(b *testing.B) { - parsedProxy, err := url.Parse(proxyURL) - if err != nil { - b.Fatalf("解析代理地址失败: %v", err) - } - settings := defaultPoolSettings(cfg) - for i := 0; i < b.N; i++ { - // 每次迭代都创建新客户端,包含 Transport 分配 - httpClientSink = &http.Client{ - Transport: buildUpstreamTransport(settings, parsedProxy), - } - } - }) - - // 子测试:复用已缓存的客户端 - // 模拟优化后的行为,从缓存获取客户端 - b.Run("复用", func(b *testing.B) { - // 预热:确保客户端已缓存 - entry := svc.getOrCreateClient(proxyURL, 1, 1) - client := entry.client - b.ResetTimer() // 重置计时器,排除预热时间 - for i := 0; i < b.N; i++ { - // 直接使用缓存的客户端,无内存分配 - httpClientSink = client - } - }) -} +package repository + +import ( + "net/http" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作 +// 这是 Go 基准测试的常见模式,确保测试结果准确 +var httpClientSink *http.Client + +// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销 +// +// 测试目的: +// - 验证连接池复用相比每次新建的性能提升 +// - 量化内存分配差异 +// +// 预期结果: +// - "复用" 子测试应显著快于 "新建" +// - "复用" 子测试应零内存分配 +func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { + // 创建测试配置 + cfg := &config.Config{ + Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300}, + } + upstream := NewHTTPUpstream(cfg) + svc, ok := upstream.(*httpUpstreamService) + if !ok { + b.Fatalf("类型断言失败,无法获取 httpUpstreamService") + } + + proxyURL := "http://127.0.0.1:8080" + b.ReportAllocs() // 报告内存分配统计 + + // 子测试:每次新建客户端 + // 模拟未优化前的行为,每次请求都创建新的 http.Client + b.Run("新建", func(b *testing.B) { + parsedProxy, err := url.Parse(proxyURL) + if err != nil { + b.Fatalf("解析代理地址失败: %v", err) + } + settings := defaultPoolSettings(cfg) + for i := 0; i < b.N; i++ { + // 每次迭代都创建新客户端,包含 Transport 分配 + transport, err := buildUpstreamTransport(settings, parsedProxy) + if err != nil { + b.Fatalf("创建 Transport 失败: %v", err) + } + httpClientSink = &http.Client{ + Transport: transport, + } + } + }) + + // 子测试:复用已缓存的客户端 + // 模拟优化后的行为,从缓存获取客户端 + b.Run("复用", func(b *testing.B) { + // 预热:确保客户端已缓存 + entry := svc.getOrCreateClient(proxyURL, 1, 1) + client := entry.client + b.ResetTimer() // 重置计时器,排除预热时间 + for i := 0; i < b.N; i++ { + // 直接使用缓存的客户端,无内存分配 + httpClientSink = client + } + }) +} diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 181976ed..f5f625f9 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -1,76 +1,75 @@ -package repository - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" - "github.com/Wei-Shaw/sub2api/internal/service" -) - -func NewProxyExitInfoProber() service.ProxyExitInfoProber { - return &proxyProbeService{ipInfoURL: defaultIPInfoURL} -} - -const defaultIPInfoURL = "https://ipinfo.io/json" - -type proxyProbeService struct { - ipInfoURL string -} - -func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { - client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 15 * time.Second, - InsecureSkipVerify: true, - ProxyStrict: true, - }) - if err != nil { - return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) - } - - startTime := time.Now() - req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil) - if err != nil { - return nil, 0, fmt.Errorf("failed to create request: %w", err) - } - - resp, err := client.Do(req) - if err != nil { - return nil, 0, fmt.Errorf("proxy connection failed: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - latencyMs := time.Since(startTime).Milliseconds() - - if resp.StatusCode != http.StatusOK { - return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) - } - - var ipInfo struct { - IP string `json:"ip"` - City string `json:"city"` - Region string `json:"region"` - Country string `json:"country"` - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) - } - - if err := json.Unmarshal(body, &ipInfo); err != nil { - return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) - } - - return &service.ProxyExitInfo{ - IP: ipInfo.IP, - City: ipInfo.City, - Region: ipInfo.Region, - Country: ipInfo.Country, - }, latencyMs, nil -} +package repository + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func NewProxyExitInfoProber() service.ProxyExitInfoProber { + return &proxyProbeService{ipInfoURL: defaultIPInfoURL} +} + +const defaultIPInfoURL = "https://ipinfo.io/json" + +type proxyProbeService struct { + ipInfoURL string +} + +func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: 15 * time.Second, + InsecureSkipVerify: true, + }) + if err != nil { + return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) + } + + startTime := time.Now() + req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("proxy connection failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + latencyMs := time.Since(startTime).Milliseconds() + + if resp.StatusCode != http.StatusOK { + return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) + } + + var ipInfo struct { + IP string `json:"ip"` + City string `json:"city"` + Region string `json:"region"` + Country string `json:"country"` + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) + } + + if err := json.Unmarshal(body, &ipInfo); err != nil { + return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) + } + + return &service.ProxyExitInfo{ + IP: ipInfo.IP, + City: ipInfo.City, + Region: ipInfo.Region, + Country: ipInfo.Country, + }, latencyMs, nil +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 02cc5dfa..4df87e9e 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -1,836 +1,803 @@ -package service - -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "regexp" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/claude" - "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" - "github.com/gin-gonic/gin" - "github.com/google/uuid" -) - -// sseDataPrefix matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var sseDataPrefix = regexp.MustCompile(`^data:\s*`) - -const ( - testClaudeAPIURL = "https://api.anthropic.com/v1/messages" - testOpenAIAPIURL = "https://api.openai.com/v1/responses" - chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" -) - -// TestEvent represents a SSE event for account testing -type TestEvent struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Model string `json:"model,omitempty"` - Success bool `json:"success,omitempty"` - Error string `json:"error,omitempty"` -} - -// AccountTestService handles account testing operations -type AccountTestService struct { - accountRepo AccountRepository - oauthService *OAuthService - openaiOAuthService *OpenAIOAuthService - geminiTokenProvider *GeminiTokenProvider - antigravityGatewayService *AntigravityGatewayService - httpUpstream HTTPUpstream -} - -// NewAccountTestService creates a new AccountTestService -func NewAccountTestService( - accountRepo AccountRepository, - oauthService *OAuthService, - openaiOAuthService *OpenAIOAuthService, - geminiTokenProvider *GeminiTokenProvider, - antigravityGatewayService *AntigravityGatewayService, - httpUpstream HTTPUpstream, -) *AccountTestService { - return &AccountTestService{ - accountRepo: accountRepo, - oauthService: oauthService, - openaiOAuthService: openaiOAuthService, - geminiTokenProvider: geminiTokenProvider, - antigravityGatewayService: antigravityGatewayService, - httpUpstream: httpUpstream, - } -} - -// generateSessionString generates a Claude Code style session string -func generateSessionString() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - hex64 := hex.EncodeToString(bytes) - sessionUUID := uuid.New().String() - return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil -} - -// createTestPayload creates a Claude Code style test request payload -func createTestPayload(modelID string) (map[string]any, error) { - sessionID, err := generateSessionString() - if err != nil { - return nil, err - } - - return map[string]any{ - "model": modelID, - "messages": []map[string]any{ - { - "role": "user", - "content": []map[string]any{ - { - "type": "text", - "text": "hi", - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }, - }, - }, - }, - "system": []map[string]any{ - { - "type": "text", - "text": "You are Claude Code, Anthropic's official CLI for Claude.", - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }, - }, - "metadata": map[string]string{ - "user_id": sessionID, - }, - "max_tokens": 1024, - "temperature": 1, - "stream": true, - }, nil -} - -// TestAccountConnection tests an account's connection by sending a test request -// All account types use full Claude Code client characteristics, only auth header differs -// modelID is optional - if empty, defaults to claude.DefaultTestModel -func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { - ctx := c.Request.Context() - - // Get account - account, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - return s.sendErrorAndEnd(c, "Account not found") - } - - // Route to platform-specific test method - if account.IsOpenAI() { - return s.testOpenAIAccountConnection(c, account, modelID) - } - - if account.IsGemini() { - return s.testGeminiAccountConnection(c, account, modelID) - } - - if account.Platform == PlatformAntigravity { - return s.testAntigravityAccountConnection(c, account, modelID) - } - - return s.testClaudeAccountConnection(c, account, modelID) -} - -// testClaudeAccountConnection tests an Anthropic Claude account's connection -func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error { - ctx := c.Request.Context() - - // Determine the model to use - testModelID := modelID - if testModelID == "" { - testModelID = claude.DefaultTestModel - } - - // For API Key accounts with model mapping, map the model - if account.Type == "apikey" { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } - } - - // Determine authentication method and API URL - var authToken string - var useBearer bool - var apiURL string - - if account.IsOAuth() { - // OAuth or Setup Token - use Bearer token - useBearer = true - apiURL = testClaudeAPIURL - authToken = account.GetCredential("access_token") - if authToken == "" { - return s.sendErrorAndEnd(c, "No access token available") - } - - // Check if token needs refresh - needRefresh := false - if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { - if time.Now().Add(5 * time.Minute).After(*expiresAt) { - needRefresh = true - } - } - - if needRefresh && s.oauthService != nil { - tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error())) - } - authToken = tokenInfo.AccessToken - } - } else if account.Type == "apikey" { - // API Key - use x-api-key header - useBearer = false - authToken = account.GetCredential("api_key") - if authToken == "" { - return s.sendErrorAndEnd(c, "No API key available") - } - - apiURL = account.GetBaseURL() - if apiURL == "" { - apiURL = "https://api.anthropic.com" - } - apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" - } else { - return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) - } - - // Set SSE headers - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - // Create Claude Code style payload (same for all account types) - payload, err := createTestPayload(testModelID) - if err != nil { - return s.sendErrorAndEnd(c, "Failed to create test payload") - } - payloadBytes, _ := json.Marshal(payload) - - // Send test_start event - s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) - - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) - if err != nil { - return s.sendErrorAndEnd(c, "Failed to create request") - } - - // Set common headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("anthropic-version", "2023-06-01") - req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) - - // Apply Claude Code client headers - for key, value := range claude.DefaultHeaders { - req.Header.Set(key, value) - } - - // Set authentication header - if useBearer { - req.Header.Set("Authorization", "Bearer "+authToken) - } else { - req.Header.Set("x-api-key", authToken) - } - - // Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) - } - - // Process SSE stream - return s.processClaudeStream(c, resp.Body) -} - -// testOpenAIAccountConnection tests an OpenAI account's connection -func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { - ctx := c.Request.Context() - - // Default to openai.DefaultTestModel for OpenAI testing - testModelID := modelID - if testModelID == "" { - testModelID = openai.DefaultTestModel - } - - // For API Key accounts with model mapping, map the model - if account.Type == "apikey" { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } - } - - // Determine authentication method and API URL - var authToken string - var apiURL string - var isOAuth bool - var chatgptAccountID string - - if account.IsOAuth() { - isOAuth = true - // OAuth - use Bearer token with ChatGPT internal API - authToken = account.GetOpenAIAccessToken() - if authToken == "" { - return s.sendErrorAndEnd(c, "No access token available") - } - - // Check if token is expired and refresh if needed - if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil { - tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error())) - } - authToken = tokenInfo.AccessToken - } - - // OAuth uses ChatGPT internal API - apiURL = chatgptCodexAPIURL - chatgptAccountID = account.GetChatGPTAccountID() - } else if account.Type == "apikey" { - // API Key - use Platform API - authToken = account.GetOpenAIApiKey() - if authToken == "" { - return s.sendErrorAndEnd(c, "No API key available") - } - - baseURL := account.GetOpenAIBaseURL() - if baseURL == "" { - baseURL = "https://api.openai.com" - } - apiURL = strings.TrimSuffix(baseURL, "/") + "/responses" - } else { - return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) - } - - // Set SSE headers - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - // Create OpenAI Responses API payload - payload := createOpenAITestPayload(testModelID, isOAuth) - payloadBytes, _ := json.Marshal(payload) - - // Send test_start event - s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) - - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) - if err != nil { - return s.sendErrorAndEnd(c, "Failed to create request") - } - - // Set common headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+authToken) - - // Set OAuth-specific headers for ChatGPT internal API - if isOAuth { - req.Host = "chatgpt.com" - req.Header.Set("accept", "text/event-stream") - if chatgptAccountID != "" { - req.Header.Set("chatgpt-account-id", chatgptAccountID) - } - } - - // Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) - } - - // Process SSE stream - return s.processOpenAIStream(c, resp.Body) -} - -// testGeminiAccountConnection tests a Gemini account's connection -func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { - ctx := c.Request.Context() - - // Determine the model to use - testModelID := modelID - if testModelID == "" { - testModelID = geminicli.DefaultTestModel - } - - // For API Key accounts with model mapping, map the model - if account.Type == AccountTypeApiKey { - mapping := account.GetModelMapping() - if len(mapping) > 0 { - if mappedModel, exists := mapping[testModelID]; exists { - testModelID = mappedModel - } - } - } - - // Set SSE headers - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - // Create test payload (Gemini format) - payload := createGeminiTestPayload() - - // Build request based on account type - var req *http.Request - var err error - - switch account.Type { - case AccountTypeApiKey: - req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) - case AccountTypeOAuth: - req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) - default: - return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) - } - - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error())) - } - - // Send test_start event - s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) - - // Get proxy and execute request - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) - } - - // Process SSE stream - return s.processGeminiStream(c, resp.Body) -} - -// testAntigravityAccountConnection tests an Antigravity account's connection -// 支持 Claude 和 Gemini 两种协议,使用非流式请求 -func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { - ctx := c.Request.Context() - - // 默认模型:Claude 使用 claude-sonnet-4-5,Gemini 使用 gemini-3-pro-preview - testModelID := modelID - if testModelID == "" { - testModelID = "claude-sonnet-4-5" - } - - if s.antigravityGatewayService == nil { - return s.sendErrorAndEnd(c, "Antigravity gateway service not configured") - } - - // Set SSE headers - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.Flush() - - // Send test_start event - s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) - - // 调用 AntigravityGatewayService.TestConnection(复用协议转换逻辑) - result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID) - if err != nil { - return s.sendErrorAndEnd(c, err.Error()) - } - - // 发送响应内容 - if result.Text != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: result.Text}) - } - - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil -} - -// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts -func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { - apiKey := account.GetCredential("api_key") - if strings.TrimSpace(apiKey) == "" { - return nil, fmt.Errorf("no API key available") - } - - baseURL := account.GetCredential("base_url") - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } - - // Use streamGenerateContent for real-time feedback - fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", - strings.TrimRight(baseURL, "/"), modelID) - - req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-goog-api-key", apiKey) - - return req, nil -} - -// buildGeminiOAuthRequest builds request for Gemini OAuth accounts -func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { - if s.geminiTokenProvider == nil { - return nil, fmt.Errorf("gemini token provider not configured") - } - - // Get access token (auto-refreshes if needed) - accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account) - if err != nil { - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - projectID := strings.TrimSpace(account.GetCredential("project_id")) - if projectID == "" { - // AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token. - baseURL := account.GetCredential("base_url") - if strings.TrimSpace(baseURL) == "" { - baseURL = geminicli.AIStudioBaseURL - } - fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+accessToken) - return req, nil - } - - // Code Assist mode (with project_id) - return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload) -} - -// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity) -func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) { - var inner map[string]any - if err := json.Unmarshal(payload, &inner); err != nil { - return nil, err - } - - wrapped := map[string]any{ - "model": modelID, - "project": projectID, - "request": inner, - } - wrappedBytes, _ := json.Marshal(wrapped) - - fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) - - req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) - - return req, nil -} - -// createGeminiTestPayload creates a minimal test payload for Gemini API -func createGeminiTestPayload() []byte { - payload := map[string]any{ - "contents": []map[string]any{ - { - "role": "user", - "parts": []map[string]any{ - {"text": "hi"}, - }, - }, - }, - "systemInstruction": map[string]any{ - "parts": []map[string]any{ - {"text": "You are a helpful AI assistant."}, - }, - }, - } - bytes, _ := json.Marshal(payload) - return bytes -} - -// processGeminiStream processes SSE stream from Gemini API -func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error { - reader := bufio.NewReader(body) - - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) - } - - line = strings.TrimSpace(line) - if line == "" || !strings.HasPrefix(line, "data: ") { - continue - } - - jsonStr := strings.TrimPrefix(line, "data: ") - if jsonStr == "[DONE]" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - var data map[string]any - if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { - continue - } - - // Support two Gemini response formats: - // - AI Studio: {"candidates": [...]} - // - Gemini CLI: {"response": {"candidates": [...]}} - if resp, ok := data["response"].(map[string]any); ok && resp != nil { - data = resp - } - if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { - if candidate, ok := candidates[0].(map[string]any); ok { - // Check for completion - if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - // Extract content - if content, ok := candidate["content"].(map[string]any); ok { - if parts, ok := content["parts"].([]any); ok { - for _, part := range parts { - if partMap, ok := part.(map[string]any); ok { - if text, ok := partMap["text"].(string); ok && text != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: text}) - } - } - } - } - } - } - } - - // Handle errors - if errData, ok := data["error"].(map[string]any); ok { - errorMsg := "Unknown error" - if msg, ok := errData["message"].(string); ok { - errorMsg = msg - } - return s.sendErrorAndEnd(c, errorMsg) - } - } -} - -// createOpenAITestPayload creates a test payload for OpenAI Responses API -func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any { - payload := map[string]any{ - "model": modelID, - "input": []map[string]any{ - { - "role": "user", - "content": []map[string]any{ - { - "type": "input_text", - "text": "hi", - }, - }, - }, - }, - "stream": true, - } - - // OAuth accounts using ChatGPT internal API require store: false - if isOAuth { - payload["store"] = false - } - - // All accounts require instructions for Responses API - payload["instructions"] = openai.DefaultInstructions - - return payload -} - -// processClaudeStream processes the SSE stream from Claude API -func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error { - reader := bufio.NewReader(body) - - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) - } - - line = strings.TrimSpace(line) - if line == "" || !sseDataPrefix.MatchString(line) { - continue - } - - jsonStr := sseDataPrefix.ReplaceAllString(line, "") - if jsonStr == "[DONE]" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - var data map[string]any - if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { - continue - } - - eventType, _ := data["type"].(string) - - switch eventType { - case "content_block_delta": - if delta, ok := data["delta"].(map[string]any); ok { - if text, ok := delta["text"].(string); ok { - s.sendEvent(c, TestEvent{Type: "content", Text: text}) - } - } - case "message_stop": - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - case "error": - errorMsg := "Unknown error" - if errData, ok := data["error"].(map[string]any); ok { - if msg, ok := errData["message"].(string); ok { - errorMsg = msg - } - } - return s.sendErrorAndEnd(c, errorMsg) - } - } -} - -// processOpenAIStream processes the SSE stream from OpenAI Responses API -func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error { - reader := bufio.NewReader(body) - - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) - } - - line = strings.TrimSpace(line) - if line == "" || !sseDataPrefix.MatchString(line) { - continue - } - - jsonStr := sseDataPrefix.ReplaceAllString(line, "") - if jsonStr == "[DONE]" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - var data map[string]any - if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { - continue - } - - eventType, _ := data["type"].(string) - - switch eventType { - case "response.output_text.delta": - // OpenAI Responses API uses "delta" field for text content - if delta, ok := data["delta"].(string); ok && delta != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: delta}) - } - case "response.completed": - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - case "error": - errorMsg := "Unknown error" - if errData, ok := data["error"].(map[string]any); ok { - if msg, ok := errData["message"].(string); ok { - errorMsg = msg - } - } - return s.sendErrorAndEnd(c, errorMsg) - } - } -} - -// sendEvent sends a SSE event to the client -func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { - eventJSON, _ := json.Marshal(event) - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { - log.Printf("failed to write SSE event: %v", err) - return - } - c.Writer.Flush() -} - -// sendErrorAndEnd sends an error event and ends the stream -func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error { - log.Printf("Account test error: %s", errorMsg) - s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) - return fmt.Errorf("%s", errorMsg) -} +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "regexp" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +// sseDataPrefix matches SSE data lines with optional whitespace after colon. +// Some upstream APIs return non-standard "data:" without space (should be "data: "). +var sseDataPrefix = regexp.MustCompile(`^data:\s*`) + +const ( + testClaudeAPIURL = "https://api.anthropic.com/v1/messages" + chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" +) + +// TestEvent represents a SSE event for account testing +type TestEvent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Model string `json:"model,omitempty"` + Success bool `json:"success,omitempty"` + Error string `json:"error,omitempty"` +} + +// AccountTestService handles account testing operations +type AccountTestService struct { + accountRepo AccountRepository + geminiTokenProvider *GeminiTokenProvider + antigravityGatewayService *AntigravityGatewayService + httpUpstream HTTPUpstream +} + +// NewAccountTestService creates a new AccountTestService +func NewAccountTestService( + accountRepo AccountRepository, + geminiTokenProvider *GeminiTokenProvider, + antigravityGatewayService *AntigravityGatewayService, + httpUpstream HTTPUpstream, +) *AccountTestService { + return &AccountTestService{ + accountRepo: accountRepo, + geminiTokenProvider: geminiTokenProvider, + antigravityGatewayService: antigravityGatewayService, + httpUpstream: httpUpstream, + } +} + +// generateSessionString generates a Claude Code style session string +func generateSessionString() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + hex64 := hex.EncodeToString(bytes) + sessionUUID := uuid.New().String() + return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil +} + +// createTestPayload creates a Claude Code style test request payload +func createTestPayload(modelID string) (map[string]any, error) { + sessionID, err := generateSessionString() + if err != nil { + return nil, err + } + + return map[string]any{ + "model": modelID, + "messages": []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + { + "type": "text", + "text": "hi", + "cache_control": map[string]string{ + "type": "ephemeral", + }, + }, + }, + }, + }, + "system": []map[string]any{ + { + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + "cache_control": map[string]string{ + "type": "ephemeral", + }, + }, + }, + "metadata": map[string]string{ + "user_id": sessionID, + }, + "max_tokens": 1024, + "temperature": 1, + "stream": true, + }, nil +} + +// TestAccountConnection tests an account's connection by sending a test request +// All account types use full Claude Code client characteristics, only auth header differs +// modelID is optional - if empty, defaults to claude.DefaultTestModel +func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error { + ctx := c.Request.Context() + + // Get account + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + return s.sendErrorAndEnd(c, "Account not found") + } + + // Route to platform-specific test method + if account.IsOpenAI() { + return s.testOpenAIAccountConnection(c, account, modelID) + } + + if account.IsGemini() { + return s.testGeminiAccountConnection(c, account, modelID) + } + + if account.Platform == PlatformAntigravity { + return s.testAntigravityAccountConnection(c, account, modelID) + } + + return s.testClaudeAccountConnection(c, account, modelID) +} + +// testClaudeAccountConnection tests an Anthropic Claude account's connection +func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error { + ctx := c.Request.Context() + + // Determine the model to use + testModelID := modelID + if testModelID == "" { + testModelID = claude.DefaultTestModel + } + + // For API Key accounts with model mapping, map the model + if account.Type == "apikey" { + mapping := account.GetModelMapping() + if len(mapping) > 0 { + if mappedModel, exists := mapping[testModelID]; exists { + testModelID = mappedModel + } + } + } + + // Determine authentication method and API URL + var authToken string + var useBearer bool + var apiURL string + + if account.IsOAuth() { + // OAuth or Setup Token - use Bearer token + useBearer = true + apiURL = testClaudeAPIURL + authToken = account.GetCredential("access_token") + if authToken == "" { + return s.sendErrorAndEnd(c, "No access token available") + } + } else if account.Type == "apikey" { + // API Key - use x-api-key header + useBearer = false + authToken = account.GetCredential("api_key") + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + + apiURL = account.GetBaseURL() + if apiURL == "" { + apiURL = "https://api.anthropic.com" + } + apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages" + } else { + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create Claude Code style payload (same for all account types) + payload, err := createTestPayload(testModelID) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create test payload") + } + payloadBytes, _ := json.Marshal(payload) + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + + // Set common headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) + + // Apply Claude Code client headers + for key, value := range claude.DefaultHeaders { + req.Header.Set(key, value) + } + + // Set authentication header + if useBearer { + req.Header.Set("Authorization", "Bearer "+authToken) + } else { + req.Header.Set("x-api-key", authToken) + } + + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Process SSE stream + return s.processClaudeStream(c, resp.Body) +} + +// testOpenAIAccountConnection tests an OpenAI account's connection +func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error { + ctx := c.Request.Context() + + // Default to openai.DefaultTestModel for OpenAI testing + testModelID := modelID + if testModelID == "" { + testModelID = openai.DefaultTestModel + } + + // For API Key accounts with model mapping, map the model + if account.Type == "apikey" { + mapping := account.GetModelMapping() + if len(mapping) > 0 { + if mappedModel, exists := mapping[testModelID]; exists { + testModelID = mappedModel + } + } + } + + // Determine authentication method and API URL + var authToken string + var apiURL string + var isOAuth bool + var chatgptAccountID string + + if account.IsOAuth() { + isOAuth = true + // OAuth - use Bearer token with ChatGPT internal API + authToken = account.GetOpenAIAccessToken() + if authToken == "" { + return s.sendErrorAndEnd(c, "No access token available") + } + + // OAuth uses ChatGPT internal API + apiURL = chatgptCodexAPIURL + chatgptAccountID = account.GetChatGPTAccountID() + } else if account.Type == "apikey" { + // API Key - use Platform API + authToken = account.GetOpenAIApiKey() + if authToken == "" { + return s.sendErrorAndEnd(c, "No API key available") + } + + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + baseURL = "https://api.openai.com" + } + apiURL = strings.TrimSuffix(baseURL, "/") + "/responses" + } else { + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create OpenAI Responses API payload + payload := createOpenAITestPayload(testModelID, isOAuth) + payloadBytes, _ := json.Marshal(payload) + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + + // Set common headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+authToken) + + // Set OAuth-specific headers for ChatGPT internal API + if isOAuth { + req.Host = "chatgpt.com" + req.Header.Set("accept", "text/event-stream") + if chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + } + + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Process SSE stream + return s.processOpenAIStream(c, resp.Body) +} + +// testGeminiAccountConnection tests a Gemini account's connection +func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error { + ctx := c.Request.Context() + + // Determine the model to use + testModelID := modelID + if testModelID == "" { + testModelID = geminicli.DefaultTestModel + } + + // For API Key accounts with model mapping, map the model + if account.Type == AccountTypeApiKey { + mapping := account.GetModelMapping() + if len(mapping) > 0 { + if mappedModel, exists := mapping[testModelID]; exists { + testModelID = mappedModel + } + } + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Create test payload (Gemini format) + payload := createGeminiTestPayload() + + // Build request based on account type + var req *http.Request + var err error + + switch account.Type { + case AccountTypeApiKey: + req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload) + case AccountTypeOAuth: + req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload) + default: + return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type)) + } + + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error())) + } + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + // Get proxy and execute request + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body))) + } + + // Process SSE stream + return s.processGeminiStream(c, resp.Body) +} + +// testAntigravityAccountConnection tests an Antigravity account's connection +// 支持 Claude 和 Gemini 两种协议,使用非流式请求 +func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { + ctx := c.Request.Context() + + // 默认模型:Claude 使用 claude-sonnet-4-5,Gemini 使用 gemini-3-pro-preview + testModelID := modelID + if testModelID == "" { + testModelID = "claude-sonnet-4-5" + } + + if s.antigravityGatewayService == nil { + return s.sendErrorAndEnd(c, "Antigravity gateway service not configured") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + + // 调用 AntigravityGatewayService.TestConnection(复用协议转换逻辑) + result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID) + if err != nil { + return s.sendErrorAndEnd(c, err.Error()) + } + + // 发送响应内容 + if result.Text != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: result.Text}) + } + + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts +func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, fmt.Errorf("no API key available") + } + + baseURL := account.GetCredential("base_url") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + // Use streamGenerateContent for real-time feedback + fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", + strings.TrimRight(baseURL, "/"), modelID) + + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-goog-api-key", apiKey) + + return req, nil +} + +// buildGeminiOAuthRequest builds request for Gemini OAuth accounts +func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) { + if s.geminiTokenProvider == nil { + return nil, fmt.Errorf("gemini token provider not configured") + } + + // Get access token (auto-refreshes if needed) + accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + // AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token. + baseURL := account.GetCredential("base_url") + if strings.TrimSpace(baseURL) == "" { + baseURL = geminicli.AIStudioBaseURL + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + return req, nil + } + + // Code Assist mode (with project_id) + return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload) +} + +// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity) +func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) { + var inner map[string]any + if err := json.Unmarshal(payload, &inner); err != nil { + return nil, err + } + + wrapped := map[string]any{ + "model": modelID, + "project": projectID, + "request": inner, + } + wrappedBytes, _ := json.Marshal(wrapped) + + fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL) + + req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + + return req, nil +} + +// createGeminiTestPayload creates a minimal test payload for Gemini API +func createGeminiTestPayload() []byte { + payload := map[string]any{ + "contents": []map[string]any{ + { + "role": "user", + "parts": []map[string]any{ + {"text": "hi"}, + }, + }, + }, + "systemInstruction": map[string]any{ + "parts": []map[string]any{ + {"text": "You are a helpful AI assistant."}, + }, + }, + } + bytes, _ := json.Marshal(payload) + return bytes +} + +// processGeminiStream processes SSE stream from Gemini API +func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error { + reader := bufio.NewReader(body) + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) + } + + line = strings.TrimSpace(line) + if line == "" || !strings.HasPrefix(line, "data: ") { + continue + } + + jsonStr := strings.TrimPrefix(line, "data: ") + if jsonStr == "[DONE]" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + var data map[string]any + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + continue + } + + // Support two Gemini response formats: + // - AI Studio: {"candidates": [...]} + // - Gemini CLI: {"response": {"candidates": [...]}} + if resp, ok := data["response"].(map[string]any); ok && resp != nil { + data = resp + } + if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { + if candidate, ok := candidates[0].(map[string]any); ok { + // Check for completion + if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + // Extract content + if content, ok := candidate["content"].(map[string]any); ok { + if parts, ok := content["parts"].([]any); ok { + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok && text != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + } + } + } + } + } + } + } + + // Handle errors + if errData, ok := data["error"].(map[string]any); ok { + errorMsg := "Unknown error" + if msg, ok := errData["message"].(string); ok { + errorMsg = msg + } + return s.sendErrorAndEnd(c, errorMsg) + } + } +} + +// createOpenAITestPayload creates a test payload for OpenAI Responses API +func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any { + payload := map[string]any{ + "model": modelID, + "input": []map[string]any{ + { + "role": "user", + "content": []map[string]any{ + { + "type": "input_text", + "text": "hi", + }, + }, + }, + }, + "stream": true, + } + + // OAuth accounts using ChatGPT internal API require store: false + if isOAuth { + payload["store"] = false + } + + // All accounts require instructions for Responses API + payload["instructions"] = openai.DefaultInstructions + + return payload +} + +// processClaudeStream processes the SSE stream from Claude API +func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error { + reader := bufio.NewReader(body) + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) + } + + line = strings.TrimSpace(line) + if line == "" || !sseDataPrefix.MatchString(line) { + continue + } + + jsonStr := sseDataPrefix.ReplaceAllString(line, "") + if jsonStr == "[DONE]" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + var data map[string]any + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + continue + } + + eventType, _ := data["type"].(string) + + switch eventType { + case "content_block_delta": + if delta, ok := data["delta"].(map[string]any); ok { + if text, ok := delta["text"].(string); ok { + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + } + } + case "message_stop": + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + case "error": + errorMsg := "Unknown error" + if errData, ok := data["error"].(map[string]any); ok { + if msg, ok := errData["message"].(string); ok { + errorMsg = msg + } + } + return s.sendErrorAndEnd(c, errorMsg) + } + } +} + +// processOpenAIStream processes the SSE stream from OpenAI Responses API +func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error { + reader := bufio.NewReader(body) + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error())) + } + + line = strings.TrimSpace(line) + if line == "" || !sseDataPrefix.MatchString(line) { + continue + } + + jsonStr := sseDataPrefix.ReplaceAllString(line, "") + if jsonStr == "[DONE]" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + var data map[string]any + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + continue + } + + eventType, _ := data["type"].(string) + + switch eventType { + case "response.output_text.delta": + // OpenAI Responses API uses "delta" field for text content + if delta, ok := data["delta"].(string); ok && delta != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: delta}) + } + case "response.completed": + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + case "error": + errorMsg := "Unknown error" + if errData, ok := data["error"].(map[string]any); ok { + if msg, ok := errData["message"].(string); ok { + errorMsg = msg + } + } + return s.sendErrorAndEnd(c, errorMsg) + } + } +} + +// sendEvent sends a SSE event to the client +func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { + eventJSON, _ := json.Marshal(event) + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { + log.Printf("failed to write SSE event: %v", err) + return + } + c.Writer.Flush() +} + +// sendErrorAndEnd sends an error event and ends the stream +func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error { + log.Printf("Account test error: %s", errorMsg) + s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) + return fmt.Errorf("%s", errorMsg) +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 4694d790..439e9508 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -1,528 +1,528 @@ -package service - -import ( - "context" - "fmt" - "log" - "sync" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" -) - -type UsageLogRepository interface { - Create(ctx context.Context, log *UsageLog) error - GetByID(ctx context.Context, id int64) (*UsageLog, error) - Delete(ctx context.Context, id int64) error - - ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) - ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) - ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) - - ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) - ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) - ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) - ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) - - GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) - GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) - - // Admin dashboard stats - GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) - GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) - GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) - GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) - GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) - GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) - - // User dashboard stats - GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) - GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) - GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) - - // Admin usage listing/stats - ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) - GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) - - // Account stats - GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) - - // Aggregated stats (optimized) - GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) - GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) - GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) - GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) - GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) -} - -// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) -type apiUsageCache struct { - response *ClaudeUsageResponse - timestamp time.Time -} - -// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost) -type windowStatsCache struct { - stats *WindowStats - timestamp time.Time -} - -// antigravityUsageCache 缓存 Antigravity 额度数据 -type antigravityUsageCache struct { - usageInfo *UsageInfo - timestamp time.Time -} - -const ( - apiCacheTTL = 10 * time.Minute - windowStatsCacheTTL = 1 * time.Minute -) - -// UsageCache 封装账户使用量相关的缓存 -type UsageCache struct { - apiCache sync.Map // accountID -> *apiUsageCache - windowStatsCache sync.Map // accountID -> *windowStatsCache - antigravityCache sync.Map // accountID -> *antigravityUsageCache -} - -// NewUsageCache 创建 UsageCache 实例 -func NewUsageCache() *UsageCache { - return &UsageCache{} -} - -// WindowStats 窗口期统计 -type WindowStats struct { - Requests int64 `json:"requests"` - Tokens int64 `json:"tokens"` - Cost float64 `json:"cost"` -} - -// UsageProgress 使用量进度 -type UsageProgress struct { - Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%) - ResetsAt *time.Time `json:"resets_at"` // 重置时间 - RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 - WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) -} - -// AntigravityModelQuota Antigravity 单个模型的配额信息 -type AntigravityModelQuota struct { - Utilization int `json:"utilization"` // 使用率 0-100 - ResetTime string `json:"reset_time"` // 重置时间 ISO8601 -} - -// UsageInfo 账号使用量信息 -type UsageInfo struct { - UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 - FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 - SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 - SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 - GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 - GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 - - // Antigravity 多模型配额 - AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` -} - -// ClaudeUsageResponse Anthropic API返回的usage结构 -type ClaudeUsageResponse struct { - FiveHour struct { - Utilization float64 `json:"utilization"` - ResetsAt string `json:"resets_at"` - } `json:"five_hour"` - SevenDay struct { - Utilization float64 `json:"utilization"` - ResetsAt string `json:"resets_at"` - } `json:"seven_day"` - SevenDaySonnet struct { - Utilization float64 `json:"utilization"` - ResetsAt string `json:"resets_at"` - } `json:"seven_day_sonnet"` -} - -// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API -type ClaudeUsageFetcher interface { - FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error) -} - -// AccountUsageService 账号使用量查询服务 -type AccountUsageService struct { - accountRepo AccountRepository - usageLogRepo UsageLogRepository - usageFetcher ClaudeUsageFetcher - geminiQuotaService *GeminiQuotaService - antigravityQuotaFetcher *AntigravityQuotaFetcher - cache *UsageCache -} - -// NewAccountUsageService 创建AccountUsageService实例 -func NewAccountUsageService( - accountRepo AccountRepository, - usageLogRepo UsageLogRepository, - usageFetcher ClaudeUsageFetcher, - geminiQuotaService *GeminiQuotaService, - antigravityQuotaFetcher *AntigravityQuotaFetcher, - cache *UsageCache, -) *AccountUsageService { - return &AccountUsageService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - usageFetcher: usageFetcher, - geminiQuotaService: geminiQuotaService, - antigravityQuotaFetcher: antigravityQuotaFetcher, - cache: cache, - } -} - -// GetUsage 获取账号使用量 -// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟 -// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) -// API Key账号: 不支持usage查询 -func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("get account failed: %w", err) - } - - if account.Platform == PlatformGemini { - return s.getGeminiUsage(ctx, account) - } - - // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 - if account.Platform == PlatformAntigravity { - return s.getAntigravityUsage(ctx, account) - } - - // 只有oauth类型账号可以通过API获取usage(有profile scope) - if account.CanGetUsage() { - var apiResp *ClaudeUsageResponse - - // 1. 检查 API 缓存(10 分钟) - if cached, ok := s.cache.apiCache.Load(accountID); ok { - if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - apiResp = cache.response - } - } - - // 2. 如果没有缓存,从 API 获取 - if apiResp == nil { - apiResp, err = s.fetchOAuthUsageRaw(ctx, account) - if err != nil { - return nil, err - } - // 缓存 API 响应 - s.cache.apiCache.Store(accountID, &apiUsageCache{ - response: apiResp, - timestamp: time.Now(), - }) - } - - // 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds) - now := time.Now() - usage := s.buildUsageInfo(apiResp, &now) - - // 4. 添加窗口统计(有独立缓存,1 分钟) - s.addWindowStats(ctx, account, usage) - - return usage, nil - } - - // Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API) - if account.Type == AccountTypeSetupToken { - usage := s.estimateSetupTokenUsage(account) - // 添加窗口统计 - s.addWindowStats(ctx, account, usage) - return usage, nil - } - - // API Key账号不支持usage查询 - return nil, fmt.Errorf("account type %s does not support usage query", account.Type) -} - -func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) { - now := time.Now() - usage := &UsageInfo{ - UpdatedAt: &now, - } - - if s.geminiQuotaService == nil || s.usageLogRepo == nil { - return usage, nil - } - - quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) - if !ok { - return usage, nil - } - - start := geminiDailyWindowStart(now) - stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) - if err != nil { - return nil, fmt.Errorf("get gemini usage stats failed: %w", err) - } - - totals := geminiAggregateUsage(stats) - resetAt := geminiDailyResetTime(now) - - usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now) - usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now) - - return usage, nil -} - -// getAntigravityUsage 获取 Antigravity 账户额度 -func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) { - if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) { - now := time.Now() - return &UsageInfo{UpdatedAt: &now}, nil - } - - // 1. 检查缓存(10 分钟) - if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { - if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { - // 重新计算 RemainingSeconds - usage := cache.usageInfo - if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { - usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) - } - return usage, nil - } - } - - // 2. 获取代理 URL - proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) - - // 3. 调用 API 获取额度 - result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) - if err != nil { - return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) - } - - // 4. 缓存结果 - s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ - usageInfo: result.UsageInfo, - timestamp: time.Now(), - }) - - return result.UsageInfo, nil -} - -// addWindowStats 为 usage 数据添加窗口期统计 -// 使用独立缓存(1 分钟),与 API 缓存分离 -func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { - // 修复:即使 FiveHour 为 nil,也要尝试获取统计数据 - // 因为 SevenDay/SevenDaySonnet 可能需要 - if usage.FiveHour == nil && usage.SevenDay == nil && usage.SevenDaySonnet == nil { - return - } - - // 检查窗口统计缓存(1 分钟) - var windowStats *WindowStats - if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok { - if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL { - windowStats = cache.stats - } - } - - // 如果没有缓存,从数据库查询 - if windowStats == nil { - var startTime time.Time - if account.SessionWindowStart != nil { - startTime = *account.SessionWindowStart - } else { - startTime = time.Now().Add(-5 * time.Hour) - } - - stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) - if err != nil { - log.Printf("Failed to get window stats for account %d: %v", account.ID, err) - return - } - - windowStats = &WindowStats{ - Requests: stats.Requests, - Tokens: stats.Tokens, - Cost: stats.Cost, - } - - // 缓存窗口统计(1 分钟) - s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{ - stats: windowStats, - timestamp: time.Now(), - }) - } - - // 为 FiveHour 添加 WindowStats(5h 窗口统计) - if usage.FiveHour != nil { - usage.FiveHour.WindowStats = windowStats - } -} - -// GetTodayStats 获取账号今日统计 -func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) { - stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("get today stats failed: %w", err) - } - - return &WindowStats{ - Requests: stats.Requests, - Tokens: stats.Tokens, - Cost: stats.Cost, - }, nil -} - -func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { - stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) - if err != nil { - return nil, fmt.Errorf("get account usage stats failed: %w", err) - } - return stats, nil -} - -// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo) -func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) { - accessToken := account.GetCredential("access_token") - if accessToken == "" { - return nil, fmt.Errorf("no access token available") - } - - var proxyURL string - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) -} - -// parseTime 尝试多种格式解析时间 -func parseTime(s string) (time.Time, error) { - formats := []string{ - time.RFC3339, - time.RFC3339Nano, - "2006-01-02T15:04:05Z", - "2006-01-02T15:04:05.000Z", - } - for _, format := range formats { - if t, err := time.Parse(format, s); err == nil { - return t, nil - } - } - return time.Time{}, fmt.Errorf("unable to parse time: %s", s) -} - -// buildUsageInfo 构建UsageInfo -func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo { - info := &UsageInfo{ - UpdatedAt: updatedAt, - } - - // 5小时窗口 - 始终创建对象(即使 ResetsAt 为空) - info.FiveHour = &UsageProgress{ - Utilization: resp.FiveHour.Utilization, - } - if resp.FiveHour.ResetsAt != "" { - if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil { - info.FiveHour.ResetsAt = &fiveHourReset - info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds()) - } else { - log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err) - } - } - - // 7天窗口 - if resp.SevenDay.ResetsAt != "" { - if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil { - info.SevenDay = &UsageProgress{ - Utilization: resp.SevenDay.Utilization, - ResetsAt: &sevenDayReset, - RemainingSeconds: int(time.Until(sevenDayReset).Seconds()), - } - } else { - log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err) - info.SevenDay = &UsageProgress{ - Utilization: resp.SevenDay.Utilization, - } - } - } - - // 7天Sonnet窗口 - if resp.SevenDaySonnet.ResetsAt != "" { - if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil { - info.SevenDaySonnet = &UsageProgress{ - Utilization: resp.SevenDaySonnet.Utilization, - ResetsAt: &sonnetReset, - RemainingSeconds: int(time.Until(sonnetReset).Seconds()), - } - } else { - log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err) - info.SevenDaySonnet = &UsageProgress{ - Utilization: resp.SevenDaySonnet.Utilization, - } - } - } - - return info -} - -// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量 -func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo { - info := &UsageInfo{} - - // 如果有session_window信息 - if account.SessionWindowEnd != nil { - remaining := int(time.Until(*account.SessionWindowEnd).Seconds()) - if remaining < 0 { - remaining = 0 - } - - // 根据状态估算使用率 (百分比形式,100 = 100%) - var utilization float64 - switch account.SessionWindowStatus { - case "rejected": - utilization = 100.0 - case "allowed_warning": - utilization = 80.0 - default: - utilization = 0.0 - } - - info.FiveHour = &UsageProgress{ - Utilization: utilization, - ResetsAt: account.SessionWindowEnd, - RemainingSeconds: remaining, - } - } else { - // 没有窗口信息,返回空数据 - info.FiveHour = &UsageProgress{ - Utilization: 0, - RemainingSeconds: 0, - } - } - - // Setup Token无法获取7d数据 - return info -} - -func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress { - if limit <= 0 { - return nil - } - utilization := (float64(used) / float64(limit)) * 100 - remainingSeconds := int(resetAt.Sub(now).Seconds()) - if remainingSeconds < 0 { - remainingSeconds = 0 - } - resetCopy := resetAt - return &UsageProgress{ - Utilization: utilization, - ResetsAt: &resetCopy, - RemainingSeconds: remainingSeconds, - WindowStats: &WindowStats{ - Requests: used, - Tokens: tokens, - Cost: cost, - }, - } -} +package service + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" +) + +type UsageLogRepository interface { + Create(ctx context.Context, log *UsageLog) error + GetByID(ctx context.Context, id int64) (*UsageLog, error) + Delete(ctx context.Context, id int64) error + + ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) + ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) + ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) + + ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error) + + GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) + GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) + + // Admin dashboard stats + GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) + GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) + GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) + GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) + GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) + GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) + GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) + + // User dashboard stats + GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) + GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) + GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) + + // Admin usage listing/stats + ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) + GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) + + // Account stats + GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) + + // Aggregated stats (optimized) + GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) + GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) +} + +// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) +type apiUsageCache struct { + response *ClaudeUsageResponse + timestamp time.Time +} + +// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost) +type windowStatsCache struct { + stats *WindowStats + timestamp time.Time +} + +// antigravityUsageCache 缓存 Antigravity 额度数据 +type antigravityUsageCache struct { + usageInfo *UsageInfo + timestamp time.Time +} + +const ( + apiCacheTTL = 3 * time.Minute + windowStatsCacheTTL = 1 * time.Minute +) + +// UsageCache 封装账户使用量相关的缓存 +type UsageCache struct { + apiCache sync.Map // accountID -> *apiUsageCache + windowStatsCache sync.Map // accountID -> *windowStatsCache + antigravityCache sync.Map // accountID -> *antigravityUsageCache +} + +// NewUsageCache 创建 UsageCache 实例 +func NewUsageCache() *UsageCache { + return &UsageCache{} +} + +// WindowStats 窗口期统计 +type WindowStats struct { + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` + Cost float64 `json:"cost"` +} + +// UsageProgress 使用量进度 +type UsageProgress struct { + Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%) + ResetsAt *time.Time `json:"resets_at"` // 重置时间 + RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 + WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) +} + +// AntigravityModelQuota Antigravity 单个模型的配额信息 +type AntigravityModelQuota struct { + Utilization int `json:"utilization"` // 使用率 0-100 + ResetTime string `json:"reset_time"` // 重置时间 ISO8601 +} + +// UsageInfo 账号使用量信息 +type UsageInfo struct { + UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间 + FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 + SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 + SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 + GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 + GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 + + // Antigravity 多模型配额 + AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"` +} + +// ClaudeUsageResponse Anthropic API返回的usage结构 +type ClaudeUsageResponse struct { + FiveHour struct { + Utilization float64 `json:"utilization"` + ResetsAt string `json:"resets_at"` + } `json:"five_hour"` + SevenDay struct { + Utilization float64 `json:"utilization"` + ResetsAt string `json:"resets_at"` + } `json:"seven_day"` + SevenDaySonnet struct { + Utilization float64 `json:"utilization"` + ResetsAt string `json:"resets_at"` + } `json:"seven_day_sonnet"` +} + +// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API +type ClaudeUsageFetcher interface { + FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error) +} + +// AccountUsageService 账号使用量查询服务 +type AccountUsageService struct { + accountRepo AccountRepository + usageLogRepo UsageLogRepository + usageFetcher ClaudeUsageFetcher + geminiQuotaService *GeminiQuotaService + antigravityQuotaFetcher *AntigravityQuotaFetcher + cache *UsageCache +} + +// NewAccountUsageService 创建AccountUsageService实例 +func NewAccountUsageService( + accountRepo AccountRepository, + usageLogRepo UsageLogRepository, + usageFetcher ClaudeUsageFetcher, + geminiQuotaService *GeminiQuotaService, + antigravityQuotaFetcher *AntigravityQuotaFetcher, + cache *UsageCache, +) *AccountUsageService { + return &AccountUsageService{ + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + usageFetcher: usageFetcher, + geminiQuotaService: geminiQuotaService, + antigravityQuotaFetcher: antigravityQuotaFetcher, + cache: cache, + } +} + +// GetUsage 获取账号使用量 +// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟 +// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) +// API Key账号: 不支持usage查询 +func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get account failed: %w", err) + } + + if account.Platform == PlatformGemini { + return s.getGeminiUsage(ctx, account) + } + + // Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度 + if account.Platform == PlatformAntigravity { + return s.getAntigravityUsage(ctx, account) + } + + // 只有oauth类型账号可以通过API获取usage(有profile scope) + if account.CanGetUsage() { + var apiResp *ClaudeUsageResponse + + // 1. 检查 API 缓存(10 分钟) + if cached, ok := s.cache.apiCache.Load(accountID); ok { + if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { + apiResp = cache.response + } + } + + // 2. 如果没有缓存,从 API 获取 + if apiResp == nil { + apiResp, err = s.fetchOAuthUsageRaw(ctx, account) + if err != nil { + return nil, err + } + // 缓存 API 响应 + s.cache.apiCache.Store(accountID, &apiUsageCache{ + response: apiResp, + timestamp: time.Now(), + }) + } + + // 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds) + now := time.Now() + usage := s.buildUsageInfo(apiResp, &now) + + // 4. 添加窗口统计(有独立缓存,1 分钟) + s.addWindowStats(ctx, account, usage) + + return usage, nil + } + + // Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API) + if account.Type == AccountTypeSetupToken { + usage := s.estimateSetupTokenUsage(account) + // 添加窗口统计 + s.addWindowStats(ctx, account, usage) + return usage, nil + } + + // API Key账号不支持usage查询 + return nil, fmt.Errorf("account type %s does not support usage query", account.Type) +} + +func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + now := time.Now() + usage := &UsageInfo{ + UpdatedAt: &now, + } + + if s.geminiQuotaService == nil || s.usageLogRepo == nil { + return usage, nil + } + + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + return usage, nil + } + + start := geminiDailyWindowStart(now) + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) + if err != nil { + return nil, fmt.Errorf("get gemini usage stats failed: %w", err) + } + + totals := geminiAggregateUsage(stats) + resetAt := geminiDailyResetTime(now) + + usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now) + usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now) + + return usage, nil +} + +// getAntigravityUsage 获取 Antigravity 账户额度 +func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) { + now := time.Now() + return &UsageInfo{UpdatedAt: &now}, nil + } + + // 1. 检查缓存(10 分钟) + if cached, ok := s.cache.antigravityCache.Load(account.ID); ok { + if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { + // 重新计算 RemainingSeconds + usage := cache.usageInfo + if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil { + usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds()) + } + return usage, nil + } + } + + // 2. 获取代理 URL + proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account) + + // 3. 调用 API 获取额度 + result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL) + if err != nil { + return nil, fmt.Errorf("fetch antigravity quota failed: %w", err) + } + + // 4. 缓存结果 + s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{ + usageInfo: result.UsageInfo, + timestamp: time.Now(), + }) + + return result.UsageInfo, nil +} + +// addWindowStats 为 usage 数据添加窗口期统计 +// 使用独立缓存(1 分钟),与 API 缓存分离 +func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { + // 修复:即使 FiveHour 为 nil,也要尝试获取统计数据 + // 因为 SevenDay/SevenDaySonnet 可能需要 + if usage.FiveHour == nil && usage.SevenDay == nil && usage.SevenDaySonnet == nil { + return + } + + // 检查窗口统计缓存(1 分钟) + var windowStats *WindowStats + if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok { + if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL { + windowStats = cache.stats + } + } + + // 如果没有缓存,从数据库查询 + if windowStats == nil { + var startTime time.Time + if account.SessionWindowStart != nil { + startTime = *account.SessionWindowStart + } else { + startTime = time.Now().Add(-5 * time.Hour) + } + + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) + if err != nil { + log.Printf("Failed to get window stats for account %d: %v", account.ID, err) + return + } + + windowStats = &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + } + + // 缓存窗口统计(1 分钟) + s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{ + stats: windowStats, + timestamp: time.Now(), + }) + } + + // 为 FiveHour 添加 WindowStats(5h 窗口统计) + if usage.FiveHour != nil { + usage.FiveHour.WindowStats = windowStats + } +} + +// GetTodayStats 获取账号今日统计 +func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) { + stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("get today stats failed: %w", err) + } + + return &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + }, nil +} + +func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) { + stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) + if err != nil { + return nil, fmt.Errorf("get account usage stats failed: %w", err) + } + return stats, nil +} + +// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo) +func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) { + accessToken := account.GetCredential("access_token") + if accessToken == "" { + return nil, fmt.Errorf("no access token available") + } + + var proxyURL string + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) +} + +// parseTime 尝试多种格式解析时间 +func parseTime(s string) (time.Time, error) { + formats := []string{ + time.RFC3339, + time.RFC3339Nano, + "2006-01-02T15:04:05Z", + "2006-01-02T15:04:05.000Z", + } + for _, format := range formats { + if t, err := time.Parse(format, s); err == nil { + return t, nil + } + } + return time.Time{}, fmt.Errorf("unable to parse time: %s", s) +} + +// buildUsageInfo 构建UsageInfo +func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo { + info := &UsageInfo{ + UpdatedAt: updatedAt, + } + + // 5小时窗口 - 始终创建对象(即使 ResetsAt 为空) + info.FiveHour = &UsageProgress{ + Utilization: resp.FiveHour.Utilization, + } + if resp.FiveHour.ResetsAt != "" { + if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil { + info.FiveHour.ResetsAt = &fiveHourReset + info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds()) + } else { + log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err) + } + } + + // 7天窗口 + if resp.SevenDay.ResetsAt != "" { + if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil { + info.SevenDay = &UsageProgress{ + Utilization: resp.SevenDay.Utilization, + ResetsAt: &sevenDayReset, + RemainingSeconds: int(time.Until(sevenDayReset).Seconds()), + } + } else { + log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err) + info.SevenDay = &UsageProgress{ + Utilization: resp.SevenDay.Utilization, + } + } + } + + // 7天Sonnet窗口 + if resp.SevenDaySonnet.ResetsAt != "" { + if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil { + info.SevenDaySonnet = &UsageProgress{ + Utilization: resp.SevenDaySonnet.Utilization, + ResetsAt: &sonnetReset, + RemainingSeconds: int(time.Until(sonnetReset).Seconds()), + } + } else { + log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err) + info.SevenDaySonnet = &UsageProgress{ + Utilization: resp.SevenDaySonnet.Utilization, + } + } + } + + return info +} + +// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量 +func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo { + info := &UsageInfo{} + + // 如果有session_window信息 + if account.SessionWindowEnd != nil { + remaining := int(time.Until(*account.SessionWindowEnd).Seconds()) + if remaining < 0 { + remaining = 0 + } + + // 根据状态估算使用率 (百分比形式,100 = 100%) + var utilization float64 + switch account.SessionWindowStatus { + case "rejected": + utilization = 100.0 + case "allowed_warning": + utilization = 80.0 + default: + utilization = 0.0 + } + + info.FiveHour = &UsageProgress{ + Utilization: utilization, + ResetsAt: account.SessionWindowEnd, + RemainingSeconds: remaining, + } + } else { + // 没有窗口信息,返回空数据 + info.FiveHour = &UsageProgress{ + Utilization: 0, + RemainingSeconds: 0, + } + } + + // Setup Token无法获取7d数据 + return info +} + +func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress { + if limit <= 0 { + return nil + } + utilization := (float64(used) / float64(limit)) * 100 + remainingSeconds := int(resetAt.Sub(now).Seconds()) + if remainingSeconds < 0 { + remainingSeconds = 0 + } + resetCopy := resetAt + return &UsageProgress{ + Utilization: utilization, + ResetsAt: &resetCopy, + RemainingSeconds: remainingSeconds, + WindowStats: &WindowStats{ + Requests: used, + Tokens: tokens, + Cost: cost, + }, + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index f9d11543..707e728b 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1,1007 +1,1008 @@ -package service - -import ( - "context" - "errors" - "fmt" - "log" - "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" -) - -// AdminService interface defines admin management operations -type AdminService interface { - // User management - ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) - GetUser(ctx context.Context, id int64) (*User, error) - CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) - UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) - DeleteUser(ctx context.Context, id int64) error - UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) - GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) - GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) - - // Group management - ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) - GetAllGroups(ctx context.Context) ([]Group, error) - GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) - GetGroup(ctx context.Context, id int64) (*Group, error) - CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) - UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) - DeleteGroup(ctx context.Context, id int64) error - GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) - - // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) - GetAccount(ctx context.Context, id int64) (*Account, error) - GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) - CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) - UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) - DeleteAccount(ctx context.Context, id int64) error - RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) - ClearAccountError(ctx context.Context, id int64) (*Account, error) - SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) - BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) - - // Proxy management - ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) - GetAllProxies(ctx context.Context) ([]Proxy, error) - GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) - GetProxy(ctx context.Context, id int64) (*Proxy, error) - CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) - UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) - DeleteProxy(ctx context.Context, id int64) error - GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) - CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) - TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) - - // Redeem code management - ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) - GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) - GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) - DeleteRedeemCode(ctx context.Context, id int64) error - BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) - ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) -} - -// Input types for admin operations -type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 -} - -type UpdateUserInput struct { - Email string - Password string - Username *string - Notes *string - Balance *float64 // 使用指针区分"未提供"和"设置为0" - Concurrency *int // 使用指针区分"未提供"和"设置为0" - Status string - AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" -} - -type CreateGroupInput struct { - Name string - Description string - Platform string - RateMultiplier float64 - IsExclusive bool - SubscriptionType string // standard/subscription - DailyLimitUSD *float64 // 日限额 (USD) - WeeklyLimitUSD *float64 // 周限额 (USD) - MonthlyLimitUSD *float64 // 月限额 (USD) -} - -type UpdateGroupInput struct { - Name string - Description string - Platform string - RateMultiplier *float64 // 使用指针以支持设置为0 - IsExclusive *bool - Status string - SubscriptionType string // standard/subscription - DailyLimitUSD *float64 // 日限额 (USD) - WeeklyLimitUSD *float64 // 周限额 (USD) - MonthlyLimitUSD *float64 // 月限额 (USD) -} - -type CreateAccountInput struct { - Name string - Platform string - Type string - Credentials map[string]any - Extra map[string]any - ProxyID *int64 - Concurrency int - Priority int - GroupIDs []int64 -} - -type UpdateAccountInput struct { - Name string - Type string // Account type: oauth, setup-token, apikey - Credentials map[string]any - Extra map[string]any - ProxyID *int64 - Concurrency *int // 使用指针区分"未提供"和"设置为0" - Priority *int // 使用指针区分"未提供"和"设置为0" - Status string - GroupIDs *[]int64 -} - -// BulkUpdateAccountsInput describes the payload for bulk updating accounts. -type BulkUpdateAccountsInput struct { - AccountIDs []int64 - Name string - ProxyID *int64 - Concurrency *int - Priority *int - Status string - GroupIDs *[]int64 - Credentials map[string]any - Extra map[string]any -} - -// BulkUpdateAccountResult captures the result for a single account update. -type BulkUpdateAccountResult struct { - AccountID int64 `json:"account_id"` - Success bool `json:"success"` - Error string `json:"error,omitempty"` -} - -// BulkUpdateAccountsResult is the aggregated response for bulk updates. -type BulkUpdateAccountsResult struct { - Success int `json:"success"` - Failed int `json:"failed"` - Results []BulkUpdateAccountResult `json:"results"` -} - -type CreateProxyInput struct { - Name string - Protocol string - Host string - Port int - Username string - Password string -} - -type UpdateProxyInput struct { - Name string - Protocol string - Host string - Port int - Username string - Password string - Status string -} - -type GenerateRedeemCodesInput struct { - Count int - Type string - Value float64 - GroupID *int64 // 订阅类型专用:关联的分组ID - ValidityDays int // 订阅类型专用:有效天数 -} - -// ProxyTestResult represents the result of testing a proxy -type ProxyTestResult struct { - Success bool `json:"success"` - Message string `json:"message"` - LatencyMs int64 `json:"latency_ms,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - City string `json:"city,omitempty"` - Region string `json:"region,omitempty"` - Country string `json:"country,omitempty"` -} - -// ProxyExitInfo represents proxy exit information from ipinfo.io -type ProxyExitInfo struct { - IP string - City string - Region string - Country string -} - -// ProxyExitInfoProber tests proxy connectivity and retrieves exit information -type ProxyExitInfoProber interface { - ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) -} - -// adminServiceImpl implements AdminService -type adminServiceImpl struct { - userRepo UserRepository - groupRepo GroupRepository - accountRepo AccountRepository - proxyRepo ProxyRepository - apiKeyRepo ApiKeyRepository - redeemCodeRepo RedeemCodeRepository - billingCacheService *BillingCacheService - proxyProber ProxyExitInfoProber -} - -// NewAdminService creates a new AdminService -func NewAdminService( - userRepo UserRepository, - groupRepo GroupRepository, - accountRepo AccountRepository, - proxyRepo ProxyRepository, - apiKeyRepo ApiKeyRepository, - redeemCodeRepo RedeemCodeRepository, - billingCacheService *BillingCacheService, - proxyProber ProxyExitInfoProber, -) AdminService { - return &adminServiceImpl{ - userRepo: userRepo, - groupRepo: groupRepo, - accountRepo: accountRepo, - proxyRepo: proxyRepo, - apiKeyRepo: apiKeyRepo, - redeemCodeRepo: redeemCodeRepo, - billingCacheService: billingCacheService, - proxyProber: proxyProber, - } -} - -// User management implementations -func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) - if err != nil { - return nil, 0, err - } - return users, result.Total, nil -} - -func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { - return s.userRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { - user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, - } - if err := user.SetPassword(input.Password); err != nil { - return nil, err - } - if err := s.userRepo.Create(ctx, user); err != nil { - return nil, err - } - return user, nil -} - -func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { - user, err := s.userRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - // Protect admin users: cannot disable admin accounts - if user.Role == "admin" && input.Status == "disabled" { - return nil, errors.New("cannot disable admin user") - } - - oldConcurrency := user.Concurrency - - if input.Email != "" { - user.Email = input.Email - } - if input.Password != "" { - if err := user.SetPassword(input.Password); err != nil { - return nil, err - } - } - - if input.Username != nil { - user.Username = *input.Username - } - if input.Notes != nil { - user.Notes = *input.Notes - } - - if input.Status != "" { - user.Status = input.Status - } - - if input.Concurrency != nil { - user.Concurrency = *input.Concurrency - } - - if input.AllowedGroups != nil { - user.AllowedGroups = *input.AllowedGroups - } - - if err := s.userRepo.Update(ctx, user); err != nil { - return nil, err - } - - concurrencyDiff := user.Concurrency - oldConcurrency - if concurrencyDiff != 0 { - code, err := GenerateRedeemCode() - if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) - return user, nil - } - adjustmentRecord := &RedeemCode{ - Code: code, - Type: AdjustmentTypeAdminConcurrency, - Value: float64(concurrencyDiff), - Status: StatusUsed, - UsedBy: &user.ID, - } - now := time.Now() - adjustmentRecord.UsedAt = &now - if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create concurrency adjustment redeem code: %v", err) - } - } - - return user, nil -} - -func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { - // Protect admin users: cannot delete admin accounts - user, err := s.userRepo.GetByID(ctx, id) - if err != nil { - return err - } - if user.Role == "admin" { - return errors.New("cannot delete admin user") - } - if err := s.userRepo.Delete(ctx, id); err != nil { - log.Printf("delete user failed: user_id=%d err=%v", id, err) - return err - } - return nil -} - -func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { - user, err := s.userRepo.GetByID(ctx, userID) - if err != nil { - return nil, err - } - - oldBalance := user.Balance - - switch operation { - case "set": - user.Balance = balance - case "add": - user.Balance += balance - case "subtract": - user.Balance -= balance - } - - if user.Balance < 0 { - return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance) - } - - if err := s.userRepo.Update(ctx, user); err != nil { - return nil, err - } - - if s.billingCacheService != nil { - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { - log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) - } - }() - } - - balanceDiff := user.Balance - oldBalance - if balanceDiff != 0 { - code, err := GenerateRedeemCode() - if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) - return user, nil - } - - adjustmentRecord := &RedeemCode{ - Code: code, - Type: AdjustmentTypeAdminBalance, - Value: balanceDiff, - Status: StatusUsed, - UsedBy: &user.ID, - Notes: notes, - } - now := time.Now() - adjustmentRecord.UsedAt = &now - - if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create balance adjustment redeem code: %v", err) - } - } - - return user, nil -} - -func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) - if err != nil { - return nil, 0, err - } - return keys, result.Total, nil -} - -func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { - // Return mock data for now - return map[string]any{ - "period": period, - "total_requests": 0, - "total_cost": 0.0, - "total_tokens": 0, - "avg_duration_ms": 0, - }, nil -} - -// Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) - if err != nil { - return nil, 0, err - } - return groups, result.Total, nil -} - -func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { - return s.groupRepo.ListActive(ctx) -} - -func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { - return s.groupRepo.ListActiveByPlatform(ctx, platform) -} - -func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { - return s.groupRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { - platform := input.Platform - if platform == "" { - platform = PlatformAnthropic - } - - subscriptionType := input.SubscriptionType - if subscriptionType == "" { - subscriptionType = SubscriptionTypeStandard - } - - // 限额字段:0 和 nil 都表示"无限制" - dailyLimit := normalizeLimit(input.DailyLimitUSD) - weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) - monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) - - group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - } - if err := s.groupRepo.Create(ctx, group); err != nil { - return nil, err - } - return group, nil -} - -// normalizeLimit 将 0 或负数转换为 nil(表示无限制) -func normalizeLimit(limit *float64) *float64 { - if limit == nil || *limit <= 0 { - return nil - } - return limit -} - -func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { - group, err := s.groupRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - if input.Name != "" { - group.Name = input.Name - } - if input.Description != "" { - group.Description = input.Description - } - if input.Platform != "" { - group.Platform = input.Platform - } - if input.RateMultiplier != nil { - group.RateMultiplier = *input.RateMultiplier - } - if input.IsExclusive != nil { - group.IsExclusive = *input.IsExclusive - } - if input.Status != "" { - group.Status = input.Status - } - - // 订阅相关字段 - if input.SubscriptionType != "" { - group.SubscriptionType = input.SubscriptionType - } - // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 - if input.DailyLimitUSD != nil { - group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) - } - if input.WeeklyLimitUSD != nil { - group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) - } - if input.MonthlyLimitUSD != nil { - group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) - } - - if err := s.groupRepo.Update(ctx, group); err != nil { - return nil, err - } - return group, nil -} - -func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { - affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) - if err != nil { - return err - } - - // 事务成功后,异步失效受影响用户的订阅缓存 - if len(affectedUserIDs) > 0 && s.billingCacheService != nil { - groupID := id - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - for _, userID := range affectedUserIDs { - if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { - log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) - } - } - }() - } - - return nil -} - -func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) - if err != nil { - return nil, 0, err - } - return keys, result.Total, nil -} - -// Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) - if err != nil { - return nil, 0, err - } - return accounts, result.Total, nil -} - -func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) { - return s.accountRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) { - if len(ids) == 0 { - return []*Account{}, nil - } - - accounts, err := s.accountRepo.GetByIDs(ctx, ids) - if err != nil { - return nil, fmt.Errorf("failed to get accounts by IDs: %w", err) - } - - return accounts, nil -} - -func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { - account := &Account{ - Name: input.Name, - Platform: input.Platform, - Type: input.Type, - Credentials: input.Credentials, - Extra: input.Extra, - ProxyID: input.ProxyID, - Concurrency: input.Concurrency, - Priority: input.Priority, - Status: StatusActive, - } - if err := s.accountRepo.Create(ctx, account); err != nil { - return nil, err - } - - // 绑定分组 - groupIDs := input.GroupIDs - // 如果没有指定分组,自动绑定对应平台的默认分组 - if len(groupIDs) == 0 { - defaultGroupName := input.Platform + "-default" - groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) - if err == nil { - for _, g := range groups { - if g.Name == defaultGroupName { - groupIDs = []int64{g.ID} - log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID) - break - } - } - } - } - - if len(groupIDs) > 0 { - if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { - return nil, err - } - } - - return account, nil -} - -func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - if input.Name != "" { - account.Name = input.Name - } - if input.Type != "" { - account.Type = input.Type - } - if len(input.Credentials) > 0 { - account.Credentials = input.Credentials - } - if len(input.Extra) > 0 { - account.Extra = input.Extra - } - if input.ProxyID != nil { - account.ProxyID = input.ProxyID - account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID - } - // 只在指针非 nil 时更新 Concurrency(支持设置为 0) - if input.Concurrency != nil { - account.Concurrency = *input.Concurrency - } - // 只在指针非 nil 时更新 Priority(支持设置为 0) - if input.Priority != nil { - account.Priority = *input.Priority - } - if input.Status != "" { - account.Status = input.Status - } - - // 先验证分组是否存在(在任何写操作之前) - if input.GroupIDs != nil { - for _, groupID := range *input.GroupIDs { - if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { - return nil, fmt.Errorf("get group: %w", err) - } - } - } - - if err := s.accountRepo.Update(ctx, account); err != nil { - return nil, err - } - - // 绑定分组 - if input.GroupIDs != nil { - if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { - return nil, err - } - } - - // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) - return s.accountRepo.GetByID(ctx, id) -} - -// BulkUpdateAccounts updates multiple accounts in one request. -// It merges credentials/extra keys instead of overwriting the whole object. -func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { - result := &BulkUpdateAccountsResult{ - Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), - } - - if len(input.AccountIDs) == 0 { - return result, nil - } - - // Prepare bulk updates for columns and JSONB fields. - repoUpdates := AccountBulkUpdate{ - Credentials: input.Credentials, - Extra: input.Extra, - } - if input.Name != "" { - repoUpdates.Name = &input.Name - } - if input.ProxyID != nil { - repoUpdates.ProxyID = input.ProxyID - } - if input.Concurrency != nil { - repoUpdates.Concurrency = input.Concurrency - } - if input.Priority != nil { - repoUpdates.Priority = input.Priority - } - if input.Status != "" { - repoUpdates.Status = &input.Status - } - - // Run bulk update for column/jsonb fields first. - if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { - return nil, err - } - - // Handle group bindings per account (requires individual operations). - for _, accountID := range input.AccountIDs { - entry := BulkUpdateAccountResult{AccountID: accountID} - - if input.GroupIDs != nil { - if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.Results = append(result.Results, entry) - continue - } - } - - entry.Success = true - result.Success++ - result.Results = append(result.Results, entry) - } - - return result, nil -} - -func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - return s.accountRepo.Delete(ctx, id) -} - -func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - // TODO: Implement refresh logic - return account, nil -} - -func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { - account, err := s.accountRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - account.Status = StatusActive - account.ErrorMessage = "" - if err := s.accountRepo.Update(ctx, account); err != nil { - return nil, err - } - return account, nil -} - -func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { - if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { - return nil, err - } - return s.accountRepo.GetByID(ctx, id) -} - -// Proxy management implementations -func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) - if err != nil { - return nil, 0, err - } - return proxies, result.Total, nil -} - -func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { - return s.proxyRepo.ListActive(ctx) -} - -func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { - return s.proxyRepo.ListActiveWithAccountCount(ctx) -} - -func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { - return s.proxyRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { - proxy := &Proxy{ - Name: input.Name, - Protocol: input.Protocol, - Host: input.Host, - Port: input.Port, - Username: input.Username, - Password: input.Password, - Status: StatusActive, - } - if err := s.proxyRepo.Create(ctx, proxy); err != nil { - return nil, err - } - return proxy, nil -} - -func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) { - proxy, err := s.proxyRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - if input.Name != "" { - proxy.Name = input.Name - } - if input.Protocol != "" { - proxy.Protocol = input.Protocol - } - if input.Host != "" { - proxy.Host = input.Host - } - if input.Port != 0 { - proxy.Port = input.Port - } - if input.Username != "" { - proxy.Username = input.Username - } - if input.Password != "" { - proxy.Password = input.Password - } - if input.Status != "" { - proxy.Status = input.Status - } - - if err := s.proxyRepo.Update(ctx, proxy); err != nil { - return nil, err - } - return proxy, nil -} - -func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { - return s.proxyRepo.Delete(ctx, id) -} - -func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) { - // Return mock data for now - would need a dedicated repository method - return []Account{}, 0, nil -} - -func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { - return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password) -} - -// Redeem code management implementations -func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) - if err != nil { - return nil, 0, err - } - return codes, result.Total, nil -} - -func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { - return s.redeemCodeRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { - // 如果是订阅类型,验证必须有 GroupID - if input.Type == RedeemTypeSubscription { - if input.GroupID == nil { - return nil, errors.New("group_id is required for subscription type") - } - // 验证分组存在且为订阅类型 - group, err := s.groupRepo.GetByID(ctx, *input.GroupID) - if err != nil { - return nil, fmt.Errorf("group not found: %w", err) - } - if !group.IsSubscriptionType() { - return nil, errors.New("group must be subscription type") - } - } - - codes := make([]RedeemCode, 0, input.Count) - for i := 0; i < input.Count; i++ { - codeValue, err := GenerateRedeemCode() - if err != nil { - return nil, err - } - code := RedeemCode{ - Code: codeValue, - Type: input.Type, - Value: input.Value, - Status: StatusUnused, - } - // 订阅类型专用字段 - if input.Type == RedeemTypeSubscription { - code.GroupID = input.GroupID - code.ValidityDays = input.ValidityDays - if code.ValidityDays <= 0 { - code.ValidityDays = 30 // 默认30天 - } - } - if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { - return nil, err - } - codes = append(codes, code) - } - return codes, nil -} - -func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { - return s.redeemCodeRepo.Delete(ctx, id) -} - -func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { - var deleted int64 - for _, id := range ids { - if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { - deleted++ - } - } - return deleted, nil -} - -func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { - code, err := s.redeemCodeRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - code.Status = StatusExpired - if err := s.redeemCodeRepo.Update(ctx, code); err != nil { - return nil, err - } - return code, nil -} - -func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { - proxy, err := s.proxyRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - proxyURL := proxy.URL() - exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) - if err != nil { - return &ProxyTestResult{ - Success: false, - Message: err.Error(), - }, nil - } - - return &ProxyTestResult{ - Success: true, - Message: "Proxy is accessible", - LatencyMs: latencyMs, - IPAddress: exitInfo.IP, - City: exitInfo.City, - Region: exitInfo.Region, - Country: exitInfo.Country, - }, nil -} +package service + +import ( + "context" + "errors" + "fmt" + "log" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// AdminService interface defines admin management operations +type AdminService interface { + // User management + ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) + GetUser(ctx context.Context, id int64) (*User, error) + CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) + UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) + DeleteUser(ctx context.Context, id int64) error + UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) + GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) + GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) + + // Group management + ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) + GetAllGroups(ctx context.Context) ([]Group, error) + GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) + GetGroup(ctx context.Context, id int64) (*Group, error) + CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) + UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) + DeleteGroup(ctx context.Context, id int64) error + GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) + + // Account management + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) + GetAccount(ctx context.Context, id int64) (*Account, error) + GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) + CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) + UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) + DeleteAccount(ctx context.Context, id int64) error + RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) + ClearAccountError(ctx context.Context, id int64) (*Account, error) + SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) + BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) + + // Proxy management + ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) + GetAllProxies(ctx context.Context) ([]Proxy, error) + GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) + GetProxy(ctx context.Context, id int64) (*Proxy, error) + CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) + UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) + DeleteProxy(ctx context.Context, id int64) error + GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) + CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) + TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) + + // Redeem code management + ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) + GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) + GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) + DeleteRedeemCode(ctx context.Context, id int64) error + BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) + ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) +} + +// Input types for admin operations +type CreateUserInput struct { + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 +} + +type UpdateUserInput struct { + Email string + Password string + Username *string + Notes *string + Balance *float64 // 使用指针区分"未提供"和"设置为0" + Concurrency *int // 使用指针区分"未提供"和"设置为0" + Status string + AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" +} + +type CreateGroupInput struct { + Name string + Description string + Platform string + RateMultiplier float64 + IsExclusive bool + SubscriptionType string // standard/subscription + DailyLimitUSD *float64 // 日限额 (USD) + WeeklyLimitUSD *float64 // 周限额 (USD) + MonthlyLimitUSD *float64 // 月限额 (USD) +} + +type UpdateGroupInput struct { + Name string + Description string + Platform string + RateMultiplier *float64 // 使用指针以支持设置为0 + IsExclusive *bool + Status string + SubscriptionType string // standard/subscription + DailyLimitUSD *float64 // 日限额 (USD) + WeeklyLimitUSD *float64 // 周限额 (USD) + MonthlyLimitUSD *float64 // 月限额 (USD) +} + +type CreateAccountInput struct { + Name string + Platform string + Type string + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency int + Priority int + GroupIDs []int64 +} + +type UpdateAccountInput struct { + Name string + Type string // Account type: oauth, setup-token, apikey + Credentials map[string]any + Extra map[string]any + ProxyID *int64 + Concurrency *int // 使用指针区分"未提供"和"设置为0" + Priority *int // 使用指针区分"未提供"和"设置为0" + Status string + GroupIDs *[]int64 +} + +// BulkUpdateAccountsInput describes the payload for bulk updating accounts. +type BulkUpdateAccountsInput struct { + AccountIDs []int64 + Name string + ProxyID *int64 + Concurrency *int + Priority *int + Status string + GroupIDs *[]int64 + Credentials map[string]any + Extra map[string]any +} + +// BulkUpdateAccountResult captures the result for a single account update. +type BulkUpdateAccountResult struct { + AccountID int64 `json:"account_id"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +// BulkUpdateAccountsResult is the aggregated response for bulk updates. +type BulkUpdateAccountsResult struct { + Success int `json:"success"` + Failed int `json:"failed"` + Results []BulkUpdateAccountResult `json:"results"` +} + +type CreateProxyInput struct { + Name string + Protocol string + Host string + Port int + Username string + Password string +} + +type UpdateProxyInput struct { + Name string + Protocol string + Host string + Port int + Username string + Password string + Status string +} + +type GenerateRedeemCodesInput struct { + Count int + Type string + Value float64 + GroupID *int64 // 订阅类型专用:关联的分组ID + ValidityDays int // 订阅类型专用:有效天数 +} + +// ProxyTestResult represents the result of testing a proxy +type ProxyTestResult struct { + Success bool `json:"success"` + Message string `json:"message"` + LatencyMs int64 `json:"latency_ms,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + City string `json:"city,omitempty"` + Region string `json:"region,omitempty"` + Country string `json:"country,omitempty"` +} + +// ProxyExitInfo represents proxy exit information from ipinfo.io +type ProxyExitInfo struct { + IP string + City string + Region string + Country string +} + +// ProxyExitInfoProber tests proxy connectivity and retrieves exit information +type ProxyExitInfoProber interface { + ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) +} + +// adminServiceImpl implements AdminService +type adminServiceImpl struct { + userRepo UserRepository + groupRepo GroupRepository + accountRepo AccountRepository + proxyRepo ProxyRepository + apiKeyRepo ApiKeyRepository + redeemCodeRepo RedeemCodeRepository + billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber +} + +// NewAdminService creates a new AdminService +func NewAdminService( + userRepo UserRepository, + groupRepo GroupRepository, + accountRepo AccountRepository, + proxyRepo ProxyRepository, + apiKeyRepo ApiKeyRepository, + redeemCodeRepo RedeemCodeRepository, + billingCacheService *BillingCacheService, + proxyProber ProxyExitInfoProber, +) AdminService { + return &adminServiceImpl{ + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + billingCacheService: billingCacheService, + proxyProber: proxyProber, + } +} + +// User management implementations +func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) + if err != nil { + return nil, 0, err + } + return users, result.Total, nil +} + +func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { + return s.userRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { + user := &User{ + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + } + if err := user.SetPassword(input.Password); err != nil { + return nil, err + } + if err := s.userRepo.Create(ctx, user); err != nil { + return nil, err + } + return user, nil +} + +func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + // Protect admin users: cannot disable admin accounts + if user.Role == "admin" && input.Status == "disabled" { + return nil, errors.New("cannot disable admin user") + } + + oldConcurrency := user.Concurrency + + if input.Email != "" { + user.Email = input.Email + } + if input.Password != "" { + if err := user.SetPassword(input.Password); err != nil { + return nil, err + } + } + + if input.Username != nil { + user.Username = *input.Username + } + if input.Notes != nil { + user.Notes = *input.Notes + } + + if input.Status != "" { + user.Status = input.Status + } + + if input.Concurrency != nil { + user.Concurrency = *input.Concurrency + } + + if input.AllowedGroups != nil { + user.AllowedGroups = *input.AllowedGroups + } + + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, err + } + + concurrencyDiff := user.Concurrency - oldConcurrency + if concurrencyDiff != 0 { + code, err := GenerateRedeemCode() + if err != nil { + log.Printf("failed to generate adjustment redeem code: %v", err) + return user, nil + } + adjustmentRecord := &RedeemCode{ + Code: code, + Type: AdjustmentTypeAdminConcurrency, + Value: float64(concurrencyDiff), + Status: StatusUsed, + UsedBy: &user.ID, + } + now := time.Now() + adjustmentRecord.UsedAt = &now + if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { + log.Printf("failed to create concurrency adjustment redeem code: %v", err) + } + } + + return user, nil +} + +func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { + // Protect admin users: cannot delete admin accounts + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return err + } + if user.Role == "admin" { + return errors.New("cannot delete admin user") + } + if err := s.userRepo.Delete(ctx, id); err != nil { + log.Printf("delete user failed: user_id=%d err=%v", id, err) + return err + } + return nil +} + +func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + oldBalance := user.Balance + + switch operation { + case "set": + user.Balance = balance + case "add": + user.Balance += balance + case "subtract": + user.Balance -= balance + } + + if user.Balance < 0 { + return nil, fmt.Errorf("balance cannot be negative, current balance: %.2f, requested operation would result in: %.2f", oldBalance, user.Balance) + } + + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, err + } + + if s.billingCacheService != nil { + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { + log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + } + }() + } + + balanceDiff := user.Balance - oldBalance + if balanceDiff != 0 { + code, err := GenerateRedeemCode() + if err != nil { + log.Printf("failed to generate adjustment redeem code: %v", err) + return user, nil + } + + adjustmentRecord := &RedeemCode{ + Code: code, + Type: AdjustmentTypeAdminBalance, + Value: balanceDiff, + Status: StatusUsed, + UsedBy: &user.ID, + Notes: notes, + } + now := time.Now() + adjustmentRecord.UsedAt = &now + + if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { + log.Printf("failed to create balance adjustment redeem code: %v", err) + } + } + + return user, nil +} + +func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) + if err != nil { + return nil, 0, err + } + return keys, result.Total, nil +} + +func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { + // Return mock data for now + return map[string]any{ + "period": period, + "total_requests": 0, + "total_cost": 0.0, + "total_tokens": 0, + "avg_duration_ms": 0, + }, nil +} + +// Group management implementations +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) + if err != nil { + return nil, 0, err + } + return groups, result.Total, nil +} + +func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { + return s.groupRepo.ListActive(ctx) +} + +func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { + return s.groupRepo.ListActiveByPlatform(ctx, platform) +} + +func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { + return s.groupRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + platform := input.Platform + if platform == "" { + platform = PlatformAnthropic + } + + subscriptionType := input.SubscriptionType + if subscriptionType == "" { + subscriptionType = SubscriptionTypeStandard + } + + // 限额字段:0 和 nil 都表示"无限制" + dailyLimit := normalizeLimit(input.DailyLimitUSD) + weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) + monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) + + group := &Group{ + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + } + if err := s.groupRepo.Create(ctx, group); err != nil { + return nil, err + } + return group, nil +} + +// normalizeLimit 将 0 或负数转换为 nil(表示无限制) +func normalizeLimit(limit *float64) *float64 { + if limit == nil || *limit <= 0 { + return nil + } + return limit +} + +func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { + group, err := s.groupRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Name != "" { + group.Name = input.Name + } + if input.Description != "" { + group.Description = input.Description + } + if input.Platform != "" { + group.Platform = input.Platform + } + if input.RateMultiplier != nil { + group.RateMultiplier = *input.RateMultiplier + } + if input.IsExclusive != nil { + group.IsExclusive = *input.IsExclusive + } + if input.Status != "" { + group.Status = input.Status + } + + // 订阅相关字段 + if input.SubscriptionType != "" { + group.SubscriptionType = input.SubscriptionType + } + // 限额字段:0 和 nil 都表示"无限制",正数表示具体限额 + if input.DailyLimitUSD != nil { + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) + } + if input.WeeklyLimitUSD != nil { + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) + } + if input.MonthlyLimitUSD != nil { + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) + } + + if err := s.groupRepo.Update(ctx, group); err != nil { + return nil, err + } + return group, nil +} + +func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) + if err != nil { + return err + } + + // 事务成功后,异步失效受影响用户的订阅缓存 + if len(affectedUserIDs) > 0 && s.billingCacheService != nil { + groupID := id + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for _, userID := range affectedUserIDs { + if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { + log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + } + } + }() + } + + return nil +} + +func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) + if err != nil { + return nil, 0, err + } + return keys, result.Total, nil +} + +// Account management implementations +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) + if err != nil { + return nil, 0, err + } + return accounts, result.Total, nil +} + +func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) { + return s.accountRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) { + if len(ids) == 0 { + return []*Account{}, nil + } + + accounts, err := s.accountRepo.GetByIDs(ctx, ids) + if err != nil { + return nil, fmt.Errorf("failed to get accounts by IDs: %w", err) + } + + return accounts, nil +} + +func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) { + account := &Account{ + Name: input.Name, + Platform: input.Platform, + Type: input.Type, + Credentials: input.Credentials, + Extra: input.Extra, + ProxyID: input.ProxyID, + Concurrency: input.Concurrency, + Priority: input.Priority, + Status: StatusActive, + Schedulable: true, + } + if err := s.accountRepo.Create(ctx, account); err != nil { + return nil, err + } + + // 绑定分组 + groupIDs := input.GroupIDs + // 如果没有指定分组,自动绑定对应平台的默认分组 + if len(groupIDs) == 0 { + defaultGroupName := input.Platform + "-default" + groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) + if err == nil { + for _, g := range groups { + if g.Name == defaultGroupName { + groupIDs = []int64{g.ID} + log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID) + break + } + } + } + } + + if len(groupIDs) > 0 { + if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { + return nil, err + } + } + + return account, nil +} + +func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) { + account, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Name != "" { + account.Name = input.Name + } + if input.Type != "" { + account.Type = input.Type + } + if len(input.Credentials) > 0 { + account.Credentials = input.Credentials + } + if len(input.Extra) > 0 { + account.Extra = input.Extra + } + if input.ProxyID != nil { + account.ProxyID = input.ProxyID + account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID + } + // 只在指针非 nil 时更新 Concurrency(支持设置为 0) + if input.Concurrency != nil { + account.Concurrency = *input.Concurrency + } + // 只在指针非 nil 时更新 Priority(支持设置为 0) + if input.Priority != nil { + account.Priority = *input.Priority + } + if input.Status != "" { + account.Status = input.Status + } + + // 先验证分组是否存在(在任何写操作之前) + if input.GroupIDs != nil { + for _, groupID := range *input.GroupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return nil, fmt.Errorf("get group: %w", err) + } + } + } + + if err := s.accountRepo.Update(ctx, account); err != nil { + return nil, err + } + + // 绑定分组 + if input.GroupIDs != nil { + if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil { + return nil, err + } + } + + // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) + return s.accountRepo.GetByID(ctx, id) +} + +// BulkUpdateAccounts updates multiple accounts in one request. +// It merges credentials/extra keys instead of overwriting the whole object. +func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { + result := &BulkUpdateAccountsResult{ + Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), + } + + if len(input.AccountIDs) == 0 { + return result, nil + } + + // Prepare bulk updates for columns and JSONB fields. + repoUpdates := AccountBulkUpdate{ + Credentials: input.Credentials, + Extra: input.Extra, + } + if input.Name != "" { + repoUpdates.Name = &input.Name + } + if input.ProxyID != nil { + repoUpdates.ProxyID = input.ProxyID + } + if input.Concurrency != nil { + repoUpdates.Concurrency = input.Concurrency + } + if input.Priority != nil { + repoUpdates.Priority = input.Priority + } + if input.Status != "" { + repoUpdates.Status = &input.Status + } + + // Run bulk update for column/jsonb fields first. + if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { + return nil, err + } + + // Handle group bindings per account (requires individual operations). + for _, accountID := range input.AccountIDs { + entry := BulkUpdateAccountResult{AccountID: accountID} + + if input.GroupIDs != nil { + if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { + entry.Success = false + entry.Error = err.Error() + result.Failed++ + result.Results = append(result.Results, entry) + continue + } + } + + entry.Success = true + result.Success++ + result.Results = append(result.Results, entry) + } + + return result, nil +} + +func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { + return s.accountRepo.Delete(ctx, id) +} + +func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { + account, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + // TODO: Implement refresh logic + return account, nil +} + +func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) { + account, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + account.Status = StatusActive + account.ErrorMessage = "" + if err := s.accountRepo.Update(ctx, account); err != nil { + return nil, err + } + return account, nil +} + +func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) { + if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { + return nil, err + } + return s.accountRepo.GetByID(ctx, id) +} + +// Proxy management implementations +func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) + if err != nil { + return nil, 0, err + } + return proxies, result.Total, nil +} + +func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { + return s.proxyRepo.ListActive(ctx) +} + +func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + return s.proxyRepo.ListActiveWithAccountCount(ctx) +} + +func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) { + return s.proxyRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) { + proxy := &Proxy{ + Name: input.Name, + Protocol: input.Protocol, + Host: input.Host, + Port: input.Port, + Username: input.Username, + Password: input.Password, + Status: StatusActive, + } + if err := s.proxyRepo.Create(ctx, proxy); err != nil { + return nil, err + } + return proxy, nil +} + +func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Name != "" { + proxy.Name = input.Name + } + if input.Protocol != "" { + proxy.Protocol = input.Protocol + } + if input.Host != "" { + proxy.Host = input.Host + } + if input.Port != 0 { + proxy.Port = input.Port + } + if input.Username != "" { + proxy.Username = input.Username + } + if input.Password != "" { + proxy.Password = input.Password + } + if input.Status != "" { + proxy.Status = input.Status + } + + if err := s.proxyRepo.Update(ctx, proxy); err != nil { + return nil, err + } + return proxy, nil +} + +func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error { + return s.proxyRepo.Delete(ctx, id) +} + +func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) { + // Return mock data for now - would need a dedicated repository method + return []Account{}, 0, nil +} + +func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { + return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password) +} + +// Redeem code management implementations +func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) + if err != nil { + return nil, 0, err + } + return codes, result.Total, nil +} + +func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { + return s.redeemCodeRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { + // 如果是订阅类型,验证必须有 GroupID + if input.Type == RedeemTypeSubscription { + if input.GroupID == nil { + return nil, errors.New("group_id is required for subscription type") + } + // 验证分组存在且为订阅类型 + group, err := s.groupRepo.GetByID(ctx, *input.GroupID) + if err != nil { + return nil, fmt.Errorf("group not found: %w", err) + } + if !group.IsSubscriptionType() { + return nil, errors.New("group must be subscription type") + } + } + + codes := make([]RedeemCode, 0, input.Count) + for i := 0; i < input.Count; i++ { + codeValue, err := GenerateRedeemCode() + if err != nil { + return nil, err + } + code := RedeemCode{ + Code: codeValue, + Type: input.Type, + Value: input.Value, + Status: StatusUnused, + } + // 订阅类型专用字段 + if input.Type == RedeemTypeSubscription { + code.GroupID = input.GroupID + code.ValidityDays = input.ValidityDays + if code.ValidityDays <= 0 { + code.ValidityDays = 30 // 默认30天 + } + } + if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { + return nil, err + } + codes = append(codes, code) + } + return codes, nil +} + +func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { + return s.redeemCodeRepo.Delete(ctx, id) +} + +func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { + var deleted int64 + for _, id := range ids { + if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { + deleted++ + } + } + return deleted, nil +} + +func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { + code, err := s.redeemCodeRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + code.Status = StatusExpired + if err := s.redeemCodeRepo.Update(ctx, code); err != nil { + return nil, err + } + return code, nil +} + +func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + proxyURL := proxy.URL() + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + return &ProxyTestResult{ + Success: false, + Message: err.Error(), + }, nil + } + + return &ProxyTestResult{ + Success: true, + Message: "Proxy is accessible", + LatencyMs: latencyMs, + IPAddress: exitInfo.IP, + City: exitInfo.City, + Region: exitInfo.Region, + Country: exitInfo.Country, + }, nil +} diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go new file mode 100644 index 00000000..b056f8fa --- /dev/null +++ b/backend/internal/service/gateway_prompt_test.go @@ -0,0 +1,233 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsClaudeCodeClient(t *testing.T) { + tests := []struct { + name string + userAgent string + metadataUserID string + want bool + }{ + { + name: "Claude Code client", + userAgent: "claude-cli/1.0.62 (darwin; arm64)", + metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + want: true, + }, + { + name: "Claude Code without version suffix", + userAgent: "claude-cli/2.0.0", + metadataUserID: "session_abc", + want: true, + }, + { + name: "Missing metadata user_id", + userAgent: "claude-cli/1.0.0", + metadataUserID: "", + want: false, + }, + { + name: "Different user agent", + userAgent: "curl/7.68.0", + metadataUserID: "user123", + want: false, + }, + { + name: "Empty user agent", + userAgent: "", + metadataUserID: "user123", + want: false, + }, + { + name: "Similar but not Claude CLI", + userAgent: "claude-api/1.0.0", + metadataUserID: "user123", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSystemIncludesClaudeCodePrompt(t *testing.T) { + tests := []struct { + name string + system any + want bool + }{ + { + name: "nil system", + system: nil, + want: false, + }, + { + name: "empty string", + system: "", + want: false, + }, + { + name: "string with Claude Code prompt", + system: claudeCodeSystemPrompt, + want: true, + }, + { + name: "string with different content", + system: "You are a helpful assistant.", + want: false, + }, + { + name: "empty array", + system: []any{}, + want: false, + }, + { + name: "array with Claude Code prompt", + system: []any{ + map[string]any{ + "type": "text", + "text": claudeCodeSystemPrompt, + }, + }, + want: true, + }, + { + name: "array with Claude Code prompt in second position", + system: []any{ + map[string]any{"type": "text", "text": "First prompt"}, + map[string]any{"type": "text", "text": claudeCodeSystemPrompt}, + }, + want: true, + }, + { + name: "array without Claude Code prompt", + system: []any{ + map[string]any{"type": "text", "text": "Custom prompt"}, + }, + want: false, + }, + { + name: "array with partial match (should not match)", + system: []any{ + map[string]any{"type": "text", "text": "You are Claude"}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := systemIncludesClaudeCodePrompt(tt.system) + require.Equal(t, tt.want, got) + }) + } +} + +func TestInjectClaudeCodePrompt(t *testing.T) { + tests := []struct { + name string + body string + system any + wantSystemLen int + wantFirstText string + wantSecondText string + }{ + { + name: "nil system", + body: `{"model":"claude-3"}`, + system: nil, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "empty string system", + body: `{"model":"claude-3"}`, + system: "", + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "string system", + body: `{"model":"claude-3"}`, + system: "Custom prompt", + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Custom prompt", + }, + { + name: "string system equals Claude Code prompt", + body: `{"model":"claude-3"}`, + system: claudeCodeSystemPrompt, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "array system", + body: `{"model":"claude-3"}`, + system: []any{map[string]any{"type": "text", "text": "Custom"}}, + // Claude Code + Custom = 2 + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Custom", + }, + { + name: "array system with existing Claude Code prompt (should dedupe)", + body: `{"model":"claude-3"}`, + system: []any{ + map[string]any{"type": "text", "text": claudeCodeSystemPrompt}, + map[string]any{"type": "text", "text": "Other"}, + }, + // Claude Code at start + Other = 2 (deduped) + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Other", + }, + { + name: "empty array", + body: `{"model":"claude-3"}`, + system: []any{}, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := injectClaudeCodePrompt([]byte(tt.body), tt.system) + + var parsed map[string]any + err := json.Unmarshal(result, &parsed) + require.NoError(t, err) + + system, ok := parsed["system"].([]any) + require.True(t, ok, "system should be an array") + require.Len(t, system, tt.wantSystemLen) + + first, ok := system[0].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantFirstText, first["text"]) + require.Equal(t, "text", first["type"]) + + // Check cache_control + cc, ok := first["cache_control"].(map[string]any) + require.True(t, ok) + require.Equal(t, "ephemeral", cc["type"]) + + if tt.wantSecondText != "" && len(system) > 1 { + second, ok := system[1].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantSecondText, second["text"]) + } + }) + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5ec6459c..0af2c328 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -1,1981 +1,2050 @@ -package service - -import ( - "bufio" - "bytes" - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "regexp" - "sort" - "strings" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/claude" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - "github.com/gin-gonic/gin" -) - -const ( - claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" - claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" - stickySessionTTL = time.Hour // 粘性会话TTL -) - -// sseDataRe matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var ( - sseDataRe = regexp.MustCompile(`^data:\s*`) - sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) -) - -// allowedHeaders 白名单headers(参考CRS项目) -var allowedHeaders = map[string]bool{ - "accept": true, - "x-stainless-retry-count": true, - "x-stainless-timeout": true, - "x-stainless-lang": true, - "x-stainless-package-version": true, - "x-stainless-os": true, - "x-stainless-arch": true, - "x-stainless-runtime": true, - "x-stainless-runtime-version": true, - "x-stainless-helper-method": true, - "anthropic-dangerous-direct-browser-access": true, - "anthropic-version": true, - "x-app": true, - "anthropic-beta": true, - "accept-language": true, - "sec-fetch-mode": true, - "user-agent": true, - "content-type": true, -} - -// GatewayCache defines cache operations for gateway service -type GatewayCache interface { - GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) - SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error - RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error -} - -type AccountWaitPlan struct { - AccountID int64 - MaxConcurrency int - Timeout time.Duration - MaxWaiting int -} - -type AccountSelectionResult struct { - Account *Account - Acquired bool - ReleaseFunc func() - WaitPlan *AccountWaitPlan // nil means no wait allowed -} - -// ClaudeUsage 表示Claude API返回的usage信息 -type ClaudeUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` -} - -// ForwardResult 转发结果 -type ForwardResult struct { - RequestID string - Usage ClaudeUsage - Model string - Stream bool - Duration time.Duration - FirstTokenMs *int // 首字时间(流式请求) -} - -// UpstreamFailoverError indicates an upstream error that should trigger account failover. -type UpstreamFailoverError struct { - StatusCode int -} - -func (e *UpstreamFailoverError) Error() string { - return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) -} - -// GatewayService handles API gateway operations -type GatewayService struct { - accountRepo AccountRepository - groupRepo GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - cache GatewayCache - cfg *config.Config - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService -} - -// NewGatewayService creates a new GatewayService -func NewGatewayService( - accountRepo AccountRepository, - groupRepo GroupRepository, - usageLogRepo UsageLogRepository, - userRepo UserRepository, - userSubRepo UserSubscriptionRepository, - cache GatewayCache, - cfg *config.Config, - concurrencyService *ConcurrencyService, - billingService *BillingService, - rateLimitService *RateLimitService, - billingCacheService *BillingCacheService, - identityService *IdentityService, - httpUpstream HTTPUpstream, - deferredService *DeferredService, -) *GatewayService { - return &GatewayService{ - accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - } -} - -// GenerateSessionHash 从预解析请求计算粘性会话 hash -func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { - if parsed == nil { - return "" - } - - // 1. 最高优先级:从 metadata.user_id 提取 session_xxx - if parsed.MetadataUserID != "" { - if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { - return match[1] - } - } - - // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 - cacheableContent := s.extractCacheableContent(parsed) - if cacheableContent != "" { - return s.hashContent(cacheableContent) - } - - // 3. Fallback: 使用 system 内容 - if parsed.System != nil { - systemText := s.extractTextFromSystem(parsed.System) - if systemText != "" { - return s.hashContent(systemText) - } - } - - // 4. 最后 fallback: 使用第一条消息 - if len(parsed.Messages) > 0 { - if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { - msgText := s.extractTextFromContent(firstMsg["content"]) - if msgText != "" { - return s.hashContent(msgText) - } - } - } - - return "" -} - -// BindStickySession sets session -> account binding with standard TTL. -func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 || s.cache == nil { - return nil - } - return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) -} - -func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { - if parsed == nil { - return "" - } - - var builder strings.Builder - - // 检查 system 中的 cacheable 内容 - if system, ok := parsed.System.([]any); ok { - for _, part := range system { - if partMap, ok := part.(map[string]any); ok { - if cc, ok := partMap["cache_control"].(map[string]any); ok { - if cc["type"] == "ephemeral" { - if text, ok := partMap["text"].(string); ok { - _, _ = builder.WriteString(text) - } - } - } - } - } - } - systemText := builder.String() - - // 检查 messages 中的 cacheable 内容 - for _, msg := range parsed.Messages { - if msgMap, ok := msg.(map[string]any); ok { - if msgContent, ok := msgMap["content"].([]any); ok { - for _, part := range msgContent { - if partMap, ok := part.(map[string]any); ok { - if cc, ok := partMap["cache_control"].(map[string]any); ok { - if cc["type"] == "ephemeral" { - return s.extractTextFromContent(msgMap["content"]) - } - } - } - } - } - } - } - - return systemText -} - -func (s *GatewayService) extractTextFromSystem(system any) string { - switch v := system.(type) { - case string: - return v - case []any: - var texts []string - for _, part := range v { - if partMap, ok := part.(map[string]any); ok { - if text, ok := partMap["text"].(string); ok { - texts = append(texts, text) - } - } - } - return strings.Join(texts, "") - } - return "" -} - -func (s *GatewayService) extractTextFromContent(content any) string { - switch v := content.(type) { - case string: - return v - case []any: - var texts []string - for _, part := range v { - if partMap, ok := part.(map[string]any); ok { - if partMap["type"] == "text" { - if text, ok := partMap["text"].(string); ok { - texts = append(texts, text) - } - } - } - } - return strings.Join(texts, "") - } - return "" -} - -func (s *GatewayService) hashContent(content string) string { - hash := sha256.Sum256([]byte(content)) - return hex.EncodeToString(hash[:16]) // 32字符 -} - -// replaceModelInBody 替换请求体中的model字段 -func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return body - } - req["model"] = newModel - newBody, err := json.Marshal(req) - if err != nil { - return body - } - return newBody -} - -// SelectAccount 选择账号(粘性会话+优先级) -func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { - return s.SelectAccountForModel(ctx, groupID, sessionHash, "") -} - -// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) -func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { - return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) -} - -// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. -func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 优先检查 context 中的强制平台(/antigravity 路由) - var platform string - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - platform = forcePlatform - } else if groupID != nil { - // 根据分组 platform 决定查询哪种账号 - group, err := s.groupRepo.GetByID(ctx, *groupID) - if err != nil { - return nil, fmt.Errorf("get group failed: %w", err) - } - platform = group.Platform - } else { - // 无分组时只使用原生 anthropic 平台 - platform = PlatformAnthropic - } - - // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) - // 注意:强制平台模式不走混合调度 - if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - } - - // 强制平台模式:优先按分组查找,找不到再查全部该平台账户 - if hasForcePlatform && groupID != nil { - account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err == nil { - return account, nil - } - // 分组中找不到,回退查询全部该平台账户 - groupID = nil - } - - // antigravity 分组、强制平台模式或无分组使用单平台选择 - return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) -} - -// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { - stickyAccountID = accountID - } - } - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err - } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - - platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) - if err != nil { - return nil, err - } - preferOAuth := platform == PlatformGemini - - accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) - if err != nil { - return nil, err - } - if len(accounts) == 0 { - return nil, errors.New("no available accounts") - } - - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded - } - - // ============ Layer 1: 粘性会话优先 ============ - if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { - account, err := s.accountRepo.GetByID(ctx, accountID) - if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && - account.IsSchedulable() && - (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } - } - } - - // ============ Layer 2: 负载感知选择 ============ - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { - continue - } - candidates = append(candidates, acc) - } - - if len(candidates) == 0 { - return nil, errors.New("no available accounts") - } - - accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.Concurrency, - }) - } - - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { - return result, nil - } - } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - if preferOAuth && a.account.Type != b.account.Type { - return a.account.Type == AccountTypeOAuth - } - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - - for _, item := range available { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: item.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - } - } - } - - // ============ Layer 3: 兜底排队 ============ - sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) - for _, acc := range candidates { - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil - } - return nil, errors.New("no available accounts") -} - -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) - - for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: acc, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, true - } - } - - return nil, false -} - -func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, - } -} - -func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - return forcePlatform, true, nil - } - if groupID != nil { - group, err := s.groupRepo.GetByID(ctx, *groupID) - if err != nil { - return "", false, fmt.Errorf("get group failed: %w", err) - } - return group.Platform, false, nil - } - return PlatformAnthropic, false, nil -} - -func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform - if useMixed { - platforms := []string{platform, PlatformAntigravity} - var accounts []Account - var err error - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) - } - if err != nil { - return nil, useMixed, err - } - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { - continue - } - filtered = append(filtered, acc) - } - return filtered, useMixed, nil - } - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) - if err == nil && len(accounts) == 0 && hasForcePlatform { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - if err != nil { - return nil, useMixed, err - } - return accounts, useMixed, nil -} - -func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { - if account == nil { - return false - } - if useMixed { - if account.Platform == platform { - return true - } - return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() - } - return account.Platform == platform -} - -func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { - sort.SliceStable(accounts, func(i, j int) bool { - a, b := accounts[i], accounts[j] - if a.Priority != b.Priority { - return a.Priority < b.Priority - } - switch { - case a.LastUsedAt == nil && b.LastUsedAt != nil: - return true - case a.LastUsedAt != nil && b.LastUsedAt == nil: - return false - case a.LastUsedAt == nil && b.LastUsedAt == nil: - if preferOAuth && a.Type != b.Type { - return a.Type == AccountTypeOAuth - } - return false - default: - return a.LastUsedAt.Before(*b.LastUsedAt) - } - }) -} - -// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) -func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - preferOAuth := platform == PlatformGemini - // 1. 查询粘性会话 - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) - if err == nil && accountID > 0 { - if _, excluded := excludedIDs[accountID]; !excluded { - account, err := s.accountRepo.GetByID(ctx, accountID) - // 检查账号平台是否匹配(确保粘性会话不会跨平台) - if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } - return account, nil - } - } - } - } - - // 2. 获取可调度账号列表(单平台) - var accounts []Account - var err error - if s.cfg.RunMode == config.RunModeSimple { - // 简易模式:忽略 groupID,查询所有可用账号 - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) - } - - // 3. 按优先级+最久未用选择(考虑模型支持) - var selected *Account - for i := range accounts { - acc := &accounts[i] - if _, excluded := excludedIDs[acc.ID]; excluded { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { - continue - } - if selected == nil { - selected = acc - continue - } - if acc.Priority < selected.Priority { - selected = acc - } else if acc.Priority == selected.Priority { - switch { - case acc.LastUsedAt == nil && selected.LastUsedAt != nil: - selected = acc - case acc.LastUsedAt != nil && selected.LastUsedAt == nil: - // keep selected (never used is preferred) - case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { - selected = acc - } - default: - if acc.LastUsedAt.Before(*selected.LastUsedAt) { - selected = acc - } - } - } - } - - if selected == nil { - if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) - } - return nil, errors.New("no available accounts") - } - - // 4. 建立粘性绑定 - if sessionHash != "" { - if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) - } - } - - return selected, nil -} - -// selectAccountWithMixedScheduling 选择账户(支持混合调度) -// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 -func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { - platforms := []string{nativePlatform, PlatformAntigravity} - preferOAuth := nativePlatform == PlatformGemini - - // 1. 查询粘性会话 - if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) - if err == nil && accountID > 0 { - if _, excluded := excludedIDs[accountID]; !excluded { - account, err := s.accountRepo.GetByID(ctx, accountID) - // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度 - if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } - return account, nil - } - } - } - } - } - - // 2. 获取可调度账号列表 - var accounts []Account - var err error - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else { - accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) - } - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) - } - - // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) - var selected *Account - for i := range accounts { - acc := &accounts[i] - if _, excluded := excludedIDs[acc.ID]; excluded { - continue - } - // 过滤:原生平台直接通过,antigravity 需要启用混合调度 - if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { - continue - } - if selected == nil { - selected = acc - continue - } - if acc.Priority < selected.Priority { - selected = acc - } else if acc.Priority == selected.Priority { - switch { - case acc.LastUsedAt == nil && selected.LastUsedAt != nil: - selected = acc - case acc.LastUsedAt != nil && selected.LastUsedAt == nil: - // keep selected (never used is preferred) - case acc.LastUsedAt == nil && selected.LastUsedAt == nil: - if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { - selected = acc - } - default: - if acc.LastUsedAt.Before(*selected.LastUsedAt) { - selected = acc - } - } - } - } - - if selected == nil { - if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) - } - return nil, errors.New("no available accounts") - } - - // 4. 建立粘性绑定 - if sessionHash != "" { - if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) - } - } - - return selected, nil -} - -// isModelSupportedByAccount 根据账户平台检查模型支持 -func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { - if account.Platform == PlatformAntigravity { - // Antigravity 平台使用专门的模型支持检查 - return IsAntigravityModelSupported(requestedModel) - } - // 其他平台使用账户的模型支持检查 - return account.IsModelSupported(requestedModel) -} - -// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 -// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持 -func IsAntigravityModelSupported(requestedModel string) bool { - return strings.HasPrefix(requestedModel, "claude-") || - strings.HasPrefix(requestedModel, "gemini-") -} - -// GetAccessToken 获取账号凭证 -func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { - switch account.Type { - case AccountTypeOAuth, AccountTypeSetupToken: - // Both oauth and setup-token use OAuth token flow - return s.getOAuthToken(ctx, account) - case AccountTypeApiKey: - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return "", "", errors.New("api_key not found in credentials") - } - return apiKey, "apikey", nil - default: - return "", "", fmt.Errorf("unsupported account type: %s", account.Type) - } -} - -func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { - accessToken := account.GetCredential("access_token") - if accessToken == "" { - return "", "", errors.New("access_token not found in credentials") - } - // Token刷新由后台 TokenRefreshService 处理,此处只返回当前token - return accessToken, "oauth", nil -} - -// 重试相关常量 -const ( - maxRetries = 10 // 最大重试次数 - retryDelay = 3 * time.Second // 重试等待时间 -) - -func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { - // OAuth/Setup Token 账号:仅 403 重试 - if account.IsOAuth() { - return statusCode == 403 - } - - // API Key 账号:未配置的错误码重试 - return !account.ShouldHandleErrorCode(statusCode) -} - -// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover. -func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { - switch statusCode { - case 401, 403, 429, 529: - return true - default: - return statusCode >= 500 - } -} - -// Forward 转发请求到Claude API -func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { - startTime := time.Now() - if parsed == nil { - return nil, fmt.Errorf("parse request: empty request") - } - - body := parsed.Body - reqModel := parsed.Model - reqStream := parsed.Stream - - if !parsed.HasSystem { - body, _ = sjson.SetBytes(body, "system", []any{ - map[string]any{ - "type": "text", - "text": "You are Claude Code, Anthropic's official CLI for Claude.", - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }, - }) - } - - // 应用模型映射(仅对apikey类型账号) - originalModel := reqModel - if account.Type == AccountTypeApiKey { - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - // 替换请求体中的模型名 - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) - } - } - - // 获取凭证 - token, tokenType, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - - // 获取代理URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 重试循环 - var resp *http.Response - for attempt := 1; attempt <= maxRetries; attempt++ { - // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) - if err != nil { - return nil, err - } - - // 发送请求 - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("upstream request failed: %w", err) - } - - // 检查是否需要重试 - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { - if attempt < maxRetries { - log.Printf("Account %d: upstream error %d, retry %d/%d after %v", - account.ID, resp.StatusCode, attempt, maxRetries, retryDelay) - _ = resp.Body.Close() - time.Sleep(retryDelay) - continue - } - // 最后一次尝试也失败,跳出循环处理重试耗尽 - break - } - - // 不需要重试(成功或不可重试的错误),跳出循环 - break - } - defer func() { _ = resp.Body.Close() }() - - // 处理重试耗尽的情况 - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - s.handleRetryExhaustedSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} - } - return s.handleRetryExhaustedError(ctx, resp, c, account) - } - - // 处理可切换账号的错误 - if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} - } - - // 处理错误响应(不可重试的错误) - if resp.StatusCode >= 400 { - // 可选:对部分 400 触发 failover(默认关闭以保持语义) - if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { - respBody, readErr := io.ReadAll(resp.Body) - if readErr != nil { - // ReadAll failed, fall back to normal error handling without consuming the stream - return s.handleErrorResponse(ctx, resp, c, account) - } - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - if s.shouldFailoverOn400(respBody) { - if s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "Account %d: 400 error, attempting failover: %s", - account.ID, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } else { - log.Printf("Account %d: 400 error, attempting failover", account.ID) - } - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} - } - } - return s.handleErrorResponse(ctx, resp, c, account) - } - - // 处理正常响应 - var usage *ClaudeUsage - var firstTokenMs *int - if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) - if err != nil { - if err.Error() == "have error in stream" { - return nil, &UpstreamFailoverError{ - StatusCode: 403, - } - } - return nil, err - } - usage = streamResult.usage - firstTokenMs = streamResult.firstTokenMs - } else { - usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) - if err != nil { - return nil, err - } - } - - return &ForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 - Stream: reqStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - }, nil -} - -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { - // 确定目标URL - targetURL := claudeAPIURL - if account.Type == AccountTypeApiKey { - baseURL := account.GetBaseURL() - targetURL = baseURL + "/v1/messages" - } - - // OAuth账号:应用统一指纹 - var fingerprint *Fingerprint - if account.IsOAuth() && s.identityService != nil { - // 1. 获取或创建指纹(包含随机生成的ClientID) - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) - if err != nil { - log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) - // 失败时降级为透传原始headers - } else { - fingerprint = fp - - // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { - body = newBody - } - } - } - } - - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - // 设置认证头 - if tokenType == "oauth" { - req.Header.Set("authorization", "Bearer "+token) - } else { - req.Header.Set("x-api-key", token) - } - - // 白名单透传headers - for key, values := range c.Request.Header { - lowerKey := strings.ToLower(key) - if allowedHeaders[lowerKey] { - for _, v := range values { - req.Header.Add(key, v) - } - } - } - - // OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头) - if fingerprint != nil { - s.identityService.ApplyFingerprint(req, fingerprint) - } - - // 确保必要的headers存在 - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") - } - if req.Header.Get("anthropic-version") == "" { - req.Header.Set("anthropic-version", "2023-06-01") - } - - // 处理anthropic-beta header(OAuth账号需要特殊处理) - if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { - // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultApiKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) - } - } - } - - return req, nil -} - -// getBetaHeader 处理anthropic-beta header -// 对于OAuth账号,需要确保包含oauth-2025-04-20 -func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { - // 如果客户端传了anthropic-beta - if clientBetaHeader != "" { - // 已包含oauth beta则直接返回 - if strings.Contains(clientBetaHeader, claude.BetaOAuth) { - return clientBetaHeader - } - - // 需要添加oauth beta - parts := strings.Split(clientBetaHeader, ",") - for i, p := range parts { - parts[i] = strings.TrimSpace(p) - } - - // 在claude-code-20250219后面插入oauth beta - claudeCodeIdx := -1 - for i, p := range parts { - if p == claude.BetaClaudeCode { - claudeCodeIdx = i - break - } - } - - if claudeCodeIdx >= 0 { - // 在claude-code后面插入 - newParts := make([]string, 0, len(parts)+1) - newParts = append(newParts, parts[:claudeCodeIdx+1]...) - newParts = append(newParts, claude.BetaOAuth) - newParts = append(newParts, parts[claudeCodeIdx+1:]...) - return strings.Join(newParts, ",") - } - - // 没有claude-code,放在第一位 - return claude.BetaOAuth + "," + clientBetaHeader - } - - // 客户端没传,根据模型生成 - // haiku 模型不需要 claude-code beta - if strings.Contains(strings.ToLower(modelID), "haiku") { - return claude.HaikuBetaHeader - } - - return claude.DefaultBetaHeader -} - -func requestNeedsBetaFeatures(body []byte) bool { - tools := gjson.GetBytes(body, "tools") - if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { - return true - } - if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { - return true - } - return false -} - -func defaultApiKeyBetaHeader(body []byte) string { - modelID := gjson.GetBytes(body, "model").String() - if strings.Contains(strings.ToLower(modelID), "haiku") { - return claude.ApiKeyHaikuBetaHeader - } - return claude.ApiKeyBetaHeader -} - -func truncateForLog(b []byte, maxBytes int) string { - if maxBytes <= 0 { - maxBytes = 2048 - } - if len(b) > maxBytes { - b = b[:maxBytes] - } - s := string(b) - // 保持一行,避免污染日志格式 - s = strings.ReplaceAll(s, "\n", "\\n") - s = strings.ReplaceAll(s, "\r", "\\r") - return s -} - -func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 - // 默认保守:无法识别则不切换。 - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) - if msg == "" { - return false - } - - // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 - // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 - if strings.Contains(msg, "anthropic-beta") || - strings.Contains(msg, "beta feature") || - strings.Contains(msg, "requires beta") { - return true - } - - // thinking/tool streaming 等兼容性约束(常见于中间转换链路) - if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { - return true - } - if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { - return true - } - - return false -} - -func extractUpstreamErrorMessage(body []byte) string { - // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} - if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { - inner := strings.TrimSpace(m) - // 有些上游会把完整 JSON 作为字符串塞进 message - if strings.HasPrefix(inner, "{") { - if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { - return innerMsg - } - } - return m - } - - // 兜底:尝试顶层 message - return gjson.GetBytes(body, "message").String() -} - -func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { - body, _ := io.ReadAll(resp.Body) - - // 处理上游错误,标记账号状态 - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - - // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) - var errType, errMsg string - var statusCode int - - switch resp.StatusCode { - case 400: - // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "Upstream 400 error (account=%d platform=%s type=%s): %s", - account.ID, - account.Platform, - account.Type, - truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } - c.Data(http.StatusBadRequest, "application/json", body) - return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) - case 401: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream authentication failed, please contact administrator" - case 403: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream access forbidden, please contact administrator" - case 429: - statusCode = http.StatusTooManyRequests - errType = "rate_limit_error" - errMsg = "Upstream rate limit exceeded, please retry later" - case 529: - statusCode = http.StatusServiceUnavailable - errType = "overloaded_error" - errMsg = "Upstream service overloaded, please retry later" - case 500, 502, 503, 504: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream service temporarily unavailable" - default: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream request failed" - } - - // 返回自定义错误响应 - c.JSON(statusCode, gin.H{ - "type": "error", - "error": gin.H{ - "type": errType, - "message": errMsg, - }, - }) - - return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) -} - -func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { - body, _ := io.ReadAll(resp.Body) - statusCode := resp.StatusCode - - // OAuth/Setup Token 账号的 403:标记账号异常 - if account.IsOAuth() && statusCode == 403 { - s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) - log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode) - } else { - // API Key 未配置错误码:不标记账号状态 - log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) - } -} - -func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { - body, _ := io.ReadAll(resp.Body) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) -} - -// handleRetryExhaustedError 处理重试耗尽后的错误 -// OAuth 403:标记账号异常 -// API Key 未配置错误码:仅返回错误,不标记账号 -func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { - s.handleRetryExhaustedSideEffects(ctx, resp, account) - - // 返回统一的重试耗尽错误响应 - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed after retries", - }, - }) - - return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) -} - -// streamingResult 流式响应结果 -type streamingResult struct { - usage *ClaudeUsage - firstTokenMs *int -} - -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { - // 更新5h窗口状态 - s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - - // 设置SSE响应头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - - // 透传其他响应头 - if v := resp.Header.Get("x-request-id"); v != "" { - c.Header("x-request-id", v) - } - - w := c.Writer - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("streaming not supported") - } - - usage := &ClaudeUsage{} - var firstTokenMs *int - scanner := bufio.NewScanner(resp.Body) - // 设置更大的buffer以处理长行 - scanner.Buffer(make([]byte, 64*1024), 1024*1024) - - needModelReplace := originalModel != mappedModel - - for scanner.Scan() { - line := scanner.Text() - if line == "event: error" { - return nil, errors.New("have error in stream") - } - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if sseDataRe.MatchString(line) { - data := sseDataRe.ReplaceAllString(line, "") - - // 如果有模型映射,替换响应中的model字段 - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - - // 转发行 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - - // 记录首字时间:第一个有效的 content_block_delta 或 message_start - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } else { - // 非 data 行直接转发 - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() - } - } - - if err := scanner.Err(); err != nil { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) - } - - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil -} - -// replaceModelInSSELine 替换SSE数据行中的model字段 -func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !sseDataRe.MatchString(line) { - return line - } - data := sseDataRe.ReplaceAllString(line, "") - if data == "" || data == "[DONE]" { - return line - } - - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // 只替换 message_start 事件中的 message.model - if event["type"] != "message_start" { - return line - } - - msg, ok := event["message"].(map[string]any) - if !ok { - return line - } - - model, ok := msg["model"].(string) - if !ok || model != fromModel { - return line - } - - msg["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - - return "data: " + string(newData) -} - -func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { - // 解析message_start获取input tokens(标准Claude API格式) - var msgStart struct { - Type string `json:"type"` - Message struct { - Usage ClaudeUsage `json:"usage"` - } `json:"message"` - } - if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { - usage.InputTokens = msgStart.Message.Usage.InputTokens - usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens - usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens - } - - // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) - var msgDelta struct { - Type string `json:"type"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - } `json:"usage"` - } - if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { - // output_tokens 总是从 message_delta 获取 - usage.OutputTokens = msgDelta.Usage.OutputTokens - - // 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API) - if usage.InputTokens == 0 { - usage.InputTokens = msgDelta.Usage.InputTokens - } - if usage.CacheCreationInputTokens == 0 { - usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens - } - if usage.CacheReadInputTokens == 0 { - usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens - } - } -} - -func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { - // 更新5h窗口状态 - s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - // 解析usage - var response struct { - Usage ClaudeUsage `json:"usage"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - // 如果有模型映射,替换响应中的model字段 - if originalModel != mappedModel { - body = s.replaceModelInResponseBody(body, mappedModel, originalModel) - } - - // 透传响应头 - for key, values := range resp.Header { - for _, value := range values { - c.Header(key, value) - } - } - - // 写入响应 - c.Data(resp.StatusCode, "application/json", body) - - return &response.Usage, nil -} - -// replaceModelInResponseBody 替换响应体中的model字段 -func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body - } - - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body - } - - resp["model"] = toModel - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - - return newBody -} - -// RecordUsageInput 记录使用量的输入参数 -type RecordUsageInput struct { - Result *ForwardResult - ApiKey *ApiKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 -} - -// RecordUsage 记录使用量并扣费(或更新订阅用量) -func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { - result := input.Result - apiKey := input.ApiKey - user := input.User - account := input.Account - subscription := input.Subscription - - // 计算费用 - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - } - - // 获取费率倍数 - multiplier := s.cfg.Default.RateMultiplier - if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier - } - - cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) - if err != nil { - log.Printf("Calculate cost failed: %v", err) - // 使用默认费用继续 - cost = &CostBreakdown{ActualCost: 0} - } - - // 判断计费方式:订阅模式 vs 余额模式 - isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := BillingTypeBalance - if isSubscriptionBilling { - billingType = BillingTypeSubscription - } - - // 创建使用日志 - durationMs := int(result.Duration.Milliseconds()) - usageLog := &UsageLog{ - UserID: user.ID, - ApiKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: result.RequestID, - Model: result.Model, - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - InputCost: cost.InputCost, - OutputCost: cost.OutputCost, - CacheCreationCost: cost.CacheCreationCost, - CacheReadCost: cost.CacheReadCost, - TotalCost: cost.TotalCost, - ActualCost: cost.ActualCost, - RateMultiplier: multiplier, - BillingType: billingType, - Stream: result.Stream, - DurationMs: &durationMs, - FirstTokenMs: result.FirstTokenMs, - CreatedAt: time.Now(), - } - - // 添加分组和订阅关联 - if apiKey.GroupID != nil { - usageLog.GroupID = apiKey.GroupID - } - if subscription != nil { - usageLog.SubscriptionID = &subscription.ID - } - - if err := s.usageLogRepo.Create(ctx, usageLog); err != nil { - log.Printf("Create usage log failed: %v", err) - } - - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil - } - - // 根据计费类型执行扣费 - if isSubscriptionBilling { - // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if cost.TotalCost > 0 { - if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) - } - // 异步更新订阅缓存 - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) - } - } else { - // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if cost.ActualCost > 0 { - if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", err) - } - // 异步更新余额缓存 - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) - } - } - - // Schedule batch update for account last_used_at - s.deferredService.ScheduleLastUsedUpdate(account.ID) - - return nil -} - -// ForwardCountTokens 转发 count_tokens 请求到上游 API -// 特点:不记录使用量、仅支持非流式响应 -func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { - if parsed == nil { - s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return fmt.Errorf("parse request: empty request") - } - - body := parsed.Body - reqModel := parsed.Model - - // Antigravity 账户不支持 count_tokens 转发,直接返回空值 - if account.Platform == PlatformAntigravity { - c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) - return nil - } - - // 应用模型映射(仅对 apikey 类型账号) - if account.Type == AccountTypeApiKey { - if reqModel != "" { - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) - } - } - } - - // 获取凭证 - token, tokenType, err := s.GetAccessToken(ctx, account) - if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") - return err - } - - // 构建上游请求 - upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) - if err != nil { - s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - return err - } - - // 获取代理URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 发送请求 - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") - return fmt.Errorf("upstream request failed: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - // 读取响应体 - respBody, err := io.ReadAll(resp.Body) - if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") - return err - } - - // 处理错误响应 - if resp.StatusCode >= 400 { - // 标记账号状态(429/529等) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - - // 记录上游错误摘要便于排障(不回显请求内容) - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( - "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", - resp.StatusCode, - account.ID, - account.Platform, - account.Type, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } - - // 返回简化的错误响应 - errMsg := "Upstream request failed" - switch resp.StatusCode { - case 429: - errMsg = "Rate limit exceeded" - case 529: - errMsg = "Service overloaded" - } - s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) - return fmt.Errorf("upstream error: %d", resp.StatusCode) - } - - // 透传成功响应 - c.Data(resp.StatusCode, "application/json", respBody) - return nil -} - -// buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { - // 确定目标 URL - targetURL := claudeAPICountTokensURL - if account.Type == AccountTypeApiKey { - baseURL := account.GetBaseURL() - targetURL = baseURL + "/v1/messages/count_tokens" - } - - // OAuth 账号:应用统一指纹和重写 userID - if account.IsOAuth() && s.identityService != nil { - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) - if err == nil { - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { - body = newBody - } - } - } - } - - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - // 设置认证头 - if tokenType == "oauth" { - req.Header.Set("authorization", "Bearer "+token) - } else { - req.Header.Set("x-api-key", token) - } - - // 白名单透传 headers - for key, values := range c.Request.Header { - lowerKey := strings.ToLower(key) - if allowedHeaders[lowerKey] { - for _, v := range values { - req.Header.Add(key, v) - } - } - } - - // OAuth 账号:应用指纹到请求头 - if account.IsOAuth() && s.identityService != nil { - fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) - if fp != nil { - s.identityService.ApplyFingerprint(req, fp) - } - } - - // 确保必要的 headers 存在 - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") - } - if req.Header.Get("anthropic-version") == "" { - req.Header.Set("anthropic-version", "2023-06-01") - } - - // OAuth 账号:处理 anthropic-beta header - if tokenType == "oauth" { - req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { - // API-key:与 messages 同步的按需 beta 注入(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultApiKeyBetaHeader(body); beta != "" { - req.Header.Set("anthropic-beta", beta) - } - } - } - - return req, nil -} - -// countTokensError 返回 count_tokens 错误响应 -func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { - c.JSON(status, gin.H{ - "type": "error", - "error": gin.H{ - "type": errType, - "message": message, - }, - }) -} - -// GetAvailableModels returns the list of models available for a group -// It aggregates model_mapping keys from all schedulable accounts in the group -func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { - var accounts []Account - var err error - - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) - } else { - accounts, err = s.accountRepo.ListSchedulable(ctx) - } - - if err != nil || len(accounts) == 0 { - return nil - } - - // Filter by platform if specified - if platform != "" { - filtered := make([]Account, 0) - for _, acc := range accounts { - if acc.Platform == platform { - filtered = append(filtered, acc) - } - } - accounts = filtered - } - - // Collect unique models from all accounts - modelSet := make(map[string]struct{}) - hasAnyMapping := false - - for _, acc := range accounts { - mapping := acc.GetModelMapping() - if len(mapping) > 0 { - hasAnyMapping = true - for model := range mapping { - modelSet[model] = struct{}{} - } - } - } - - // If no account has model_mapping, return nil (use default) - if !hasAnyMapping { - return nil - } - - // Convert to slice - models := make([]string, 0, len(modelSet)) - for model := range modelSet { - models = append(models, model) - } - - return models -} +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "regexp" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/gin-gonic/gin" +) + +const ( + claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" + claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" + stickySessionTTL = time.Hour // 粘性会话TTL + claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." +) + +// sseDataRe matches SSE data lines with optional whitespace after colon. +// Some upstream APIs return non-standard "data:" without space (should be "data: "). +var ( + sseDataRe = regexp.MustCompile(`^data:\s*`) + sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) + claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) +) + +// allowedHeaders 白名单headers(参考CRS项目) +var allowedHeaders = map[string]bool{ + "accept": true, + "x-stainless-retry-count": true, + "x-stainless-timeout": true, + "x-stainless-lang": true, + "x-stainless-package-version": true, + "x-stainless-os": true, + "x-stainless-arch": true, + "x-stainless-runtime": true, + "x-stainless-runtime-version": true, + "x-stainless-helper-method": true, + "anthropic-dangerous-direct-browser-access": true, + "anthropic-version": true, + "x-app": true, + "anthropic-beta": true, + "accept-language": true, + "sec-fetch-mode": true, + "user-agent": true, + "content-type": true, +} + +// GatewayCache defines cache operations for gateway service +type GatewayCache interface { + GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) + SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error + RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error +} + +type AccountWaitPlan struct { + AccountID int64 + MaxConcurrency int + Timeout time.Duration + MaxWaiting int +} + +type AccountSelectionResult struct { + Account *Account + Acquired bool + ReleaseFunc func() + WaitPlan *AccountWaitPlan // nil means no wait allowed +} + +// ClaudeUsage 表示Claude API返回的usage信息 +type ClaudeUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} + +// ForwardResult 转发结果 +type ForwardResult struct { + RequestID string + Usage ClaudeUsage + Model string + Stream bool + Duration time.Duration + FirstTokenMs *int // 首字时间(流式请求) +} + +// UpstreamFailoverError indicates an upstream error that should trigger account failover. +type UpstreamFailoverError struct { + StatusCode int +} + +func (e *UpstreamFailoverError) Error() string { + return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) +} + +// GatewayService handles API gateway operations +type GatewayService struct { + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + cache GatewayCache + cfg *config.Config + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService +} + +// NewGatewayService creates a new GatewayService +func NewGatewayService( + accountRepo AccountRepository, + groupRepo GroupRepository, + usageLogRepo UsageLogRepository, + userRepo UserRepository, + userSubRepo UserSubscriptionRepository, + cache GatewayCache, + cfg *config.Config, + concurrencyService *ConcurrencyService, + billingService *BillingService, + rateLimitService *RateLimitService, + billingCacheService *BillingCacheService, + identityService *IdentityService, + httpUpstream HTTPUpstream, + deferredService *DeferredService, +) *GatewayService { + return &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + } +} + +// GenerateSessionHash 从预解析请求计算粘性会话 hash +func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + // 1. 最高优先级:从 metadata.user_id 提取 session_xxx + if parsed.MetadataUserID != "" { + if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { + return match[1] + } + } + + // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 + cacheableContent := s.extractCacheableContent(parsed) + if cacheableContent != "" { + return s.hashContent(cacheableContent) + } + + // 3. Fallback: 使用 system 内容 + if parsed.System != nil { + systemText := s.extractTextFromSystem(parsed.System) + if systemText != "" { + return s.hashContent(systemText) + } + } + + // 4. 最后 fallback: 使用第一条消息 + if len(parsed.Messages) > 0 { + if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { + msgText := s.extractTextFromContent(firstMsg["content"]) + if msgText != "" { + return s.hashContent(msgText) + } + } + } + + return "" +} + +// BindStickySession sets session -> account binding with standard TTL. +func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { + if sessionHash == "" || accountID <= 0 || s.cache == nil { + return nil + } + return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) +} + +func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + var builder strings.Builder + + // 检查 system 中的 cacheable 内容 + if system, ok := parsed.System.([]any); ok { + for _, part := range system { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + if text, ok := partMap["text"].(string); ok { + _, _ = builder.WriteString(text) + } + } + } + } + } + } + systemText := builder.String() + + // 检查 messages 中的 cacheable 内容 + for _, msg := range parsed.Messages { + if msgMap, ok := msg.(map[string]any); ok { + if msgContent, ok := msgMap["content"].([]any); ok { + for _, part := range msgContent { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + return s.extractTextFromContent(msgMap["content"]) + } + } + } + } + } + } + } + + return systemText +} + +func (s *GatewayService) extractTextFromSystem(system any) string { + switch v := system.(type) { + case string: + return v + case []any: + var texts []string + for _, part := range v { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + texts = append(texts, text) + } + } + } + return strings.Join(texts, "") + } + return "" +} + +func (s *GatewayService) extractTextFromContent(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var texts []string + for _, part := range v { + if partMap, ok := part.(map[string]any); ok { + if partMap["type"] == "text" { + if text, ok := partMap["text"].(string); ok { + texts = append(texts, text) + } + } + } + } + return strings.Join(texts, "") + } + return "" +} + +func (s *GatewayService) hashContent(content string) string { + hash := sha256.Sum256([]byte(content)) + return hex.EncodeToString(hash[:16]) // 32字符 +} + +// replaceModelInBody 替换请求体中的model字段 +func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return body + } + req["model"] = newModel + newBody, err := json.Marshal(req) + if err != nil { + return body + } + return newBody +} + +// SelectAccount 选择账号(粘性会话+优先级) +func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { + return s.SelectAccountForModel(ctx, groupID, sessionHash, "") +} + +// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) +func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. +func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 优先检查 context 中的强制平台(/antigravity 路由) + var platform string + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + platform = forcePlatform + } else if groupID != nil { + // 根据分组 platform 决定查询哪种账号 + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) + } + platform = group.Platform + } else { + // 无分组时只使用原生 anthropic 平台 + platform = PlatformAnthropic + } + + // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) + // 注意:强制平台模式不走混合调度 + if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { + return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + } + + // 强制平台模式:优先按分组查找,找不到再查全部该平台账户 + if hasForcePlatform && groupID != nil { + account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + if err == nil { + return account, nil + } + // 分组中找不到,回退查询全部该平台账户 + groupID = nil + } + + // antigravity 分组、强制平台模式或无分组使用单平台选择 + return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) +} + +// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + cfg := s.schedulingConfig() + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + if err != nil { + return nil, err + } + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + + platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) + if err != nil { + return nil, err + } + preferOAuth := platform == PlatformGemini + + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, err + } + if len(accounts) == 0 { + return nil, errors.New("no available accounts") + } + + isExcluded := func(accountID int64) bool { + if excludedIDs == nil { + return false + } + _, excluded := excludedIDs[accountID] + return excluded + } + + // ============ Layer 1: 粘性会话优先 ============ + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + if err == nil && accountID > 0 && !isExcluded(accountID) { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) && + account.IsSchedulable() && + (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if err == nil && result.Acquired { + _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + } + } + + // ============ Layer 2: 负载感知选择 ============ + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + acc := &accounts[i] + if isExcluded(acc.ID) { + continue + } + if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + candidates = append(candidates, acc) + } + + if len(candidates) == 0 { + return nil, errors.New("no available accounts") + } + + accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) + for _, acc := range candidates { + accountLoads = append(accountLoads, AccountWithConcurrency{ + ID: acc.ID, + MaxConcurrency: acc.Concurrency, + }) + } + + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) + if err != nil { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { + return result, nil + } + } else { + type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo + } + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) > 0 { + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + if preferOAuth && a.account.Type != b.account.Type { + return a.account.Type == AccountTypeOAuth + } + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + + for _, item := range available { + result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + } + } + } + + // ============ Layer 3: 兜底排队 ============ + sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) + for _, acc := range candidates { + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil + } + return nil, errors.New("no available accounts") +} + +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { + ordered := append([]*Account(nil), candidates...) + sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) + + for _, acc := range ordered { + result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) + } + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true + } + } + + return nil, false +} + +func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} + +func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, true, nil + } + if groupID != nil { + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return "", false, fmt.Errorf("get group failed: %w", err) + } + return group.Platform, false, nil + } + return PlatformAnthropic, false, nil +} + +func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform + if useMixed { + platforms := []string{platform, PlatformAntigravity} + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } + if err != nil { + return nil, useMixed, err + } + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + filtered = append(filtered, acc) + } + return filtered, useMixed, nil + } + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) + if err == nil && len(accounts) == 0 && hasForcePlatform { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + if err != nil { + return nil, useMixed, err + } + return accounts, useMixed, nil +} + +func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { + if account == nil { + return false + } + if useMixed { + if account.Platform == platform { + return true + } + return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() + } + return account.Platform == platform +} + +func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + switch { + case a.LastUsedAt == nil && b.LastUsedAt != nil: + return true + case a.LastUsedAt != nil && b.LastUsedAt == nil: + return false + case a.LastUsedAt == nil && b.LastUsedAt == nil: + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + default: + return a.LastUsedAt.Before(*b.LastUsedAt) + } + }) +} + +// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) +func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + preferOAuth := platform == PlatformGemini + // 1. 查询粘性会话 + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + if err == nil && accountID > 0 { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.accountRepo.GetByID(ctx, accountID) + // 检查账号平台是否匹配(确保粘性会话不会跨平台) + if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + return account, nil + } + } + } + } + + // 2. 获取可调度账号列表(单平台) + var accounts []Account + var err error + if s.cfg.RunMode == config.RunModeSimple { + // 简易模式:忽略 groupID,查询所有可用账号 + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + + // 3. 按优先级+最久未用选择(考虑模型支持) + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected == nil { + if requestedModel != "" { + return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + } + return nil, errors.New("no available accounts") + } + + // 4. 建立粘性绑定 + if sessionHash != "" { + if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + + return selected, nil +} + +// selectAccountWithMixedScheduling 选择账户(支持混合调度) +// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 +func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { + platforms := []string{nativePlatform, PlatformAntigravity} + preferOAuth := nativePlatform == PlatformGemini + + // 1. 查询粘性会话 + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + if err == nil && accountID > 0 { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.accountRepo.GetByID(ctx, accountID) + // 检查账号是否有效:原生平台直接匹配,antigravity 需要启用混合调度 + if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { + if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { + if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } + return account, nil + } + } + } + } + } + + // 2. 获取可调度账号列表 + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + + // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // 过滤:原生平台直接通过,antigravity 需要启用混合调度 + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + switch { + case acc.LastUsedAt == nil && selected.LastUsedAt != nil: + selected = acc + case acc.LastUsedAt != nil && selected.LastUsedAt == nil: + // keep selected (never used is preferred) + case acc.LastUsedAt == nil && selected.LastUsedAt == nil: + if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { + selected = acc + } + default: + if acc.LastUsedAt.Before(*selected.LastUsedAt) { + selected = acc + } + } + } + } + + if selected == nil { + if requestedModel != "" { + return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + } + return nil, errors.New("no available accounts") + } + + // 4. 建立粘性绑定 + if sessionHash != "" { + if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + + return selected, nil +} + +// isModelSupportedByAccount 根据账户平台检查模型支持 +func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + // Antigravity 平台使用专门的模型支持检查 + return IsAntigravityModelSupported(requestedModel) + } + // 其他平台使用账户的模型支持检查 + return account.IsModelSupported(requestedModel) +} + +// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型 +// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持 +func IsAntigravityModelSupported(requestedModel string) bool { + return strings.HasPrefix(requestedModel, "claude-") || + strings.HasPrefix(requestedModel, "gemini-") +} + +// GetAccessToken 获取账号凭证 +func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { + switch account.Type { + case AccountTypeOAuth, AccountTypeSetupToken: + // Both oauth and setup-token use OAuth token flow + return s.getOAuthToken(ctx, account) + case AccountTypeApiKey: + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return "", "", errors.New("api_key not found in credentials") + } + return apiKey, "apikey", nil + default: + return "", "", fmt.Errorf("unsupported account type: %s", account.Type) + } +} + +func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { + accessToken := account.GetCredential("access_token") + if accessToken == "" { + return "", "", errors.New("access_token not found in credentials") + } + // Token刷新由后台 TokenRefreshService 处理,此处只返回当前token + return accessToken, "oauth", nil +} + +// 重试相关常量 +const ( + maxRetries = 10 // 最大重试次数 + retryDelay = 3 * time.Second // 重试等待时间 +) + +func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { + // OAuth/Setup Token 账号:仅 403 重试 + if account.IsOAuth() { + return statusCode == 403 + } + + // API Key 账号:未配置的错误码重试 + return !account.ShouldHandleErrorCode(statusCode) +} + +// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover. +func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 + } +} + +// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 +// 简化判断:User-Agent 匹配 + metadata.user_id 存在 +func isClaudeCodeClient(userAgent string, metadataUserID string) bool { + if metadataUserID == "" { + return false + } + return claudeCliUserAgentRe.MatchString(userAgent) +} + +// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 +// 支持 string 和 []any 两种格式 +func systemIncludesClaudeCodePrompt(system any) bool { + switch v := system.(type) { + case string: + return v == claudeCodeSystemPrompt + case []any: + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + return true + } + } + } + } + return false +} + +// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 +// 处理 null、字符串、数组三种格式 +func injectClaudeCodePrompt(body []byte, system any) []byte { + claudeCodeBlock := map[string]any{ + "type": "text", + "text": claudeCodeSystemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + } + + var newSystem []any + + switch v := system.(type) { + case nil: + newSystem = []any{claudeCodeBlock} + case string: + if v == "" || v == claudeCodeSystemPrompt { + newSystem = []any{claudeCodeBlock} + } else { + newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}} + } + case []any: + newSystem = make([]any, 0, len(v)+1) + newSystem = append(newSystem, claudeCodeBlock) + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + continue + } + } + newSystem = append(newSystem, item) + } + default: + newSystem = []any{claudeCodeBlock} + } + + result, err := sjson.SetBytes(body, "system", newSystem) + if err != nil { + log.Printf("Warning: failed to inject Claude Code prompt: %v", err) + return body + } + return result +} + +// Forward 转发请求到Claude API +func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { + startTime := time.Now() + if parsed == nil { + return nil, fmt.Errorf("parse request: empty request") + } + + body := parsed.Body + reqModel := parsed.Model + reqStream := parsed.Stream + + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + if account.IsOAuth() && + !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) && + !strings.Contains(strings.ToLower(reqModel), "haiku") && + !systemIncludesClaudeCodePrompt(parsed.System) { + body = injectClaudeCodePrompt(body, parsed.System) + } + + // 应用模型映射(仅对apikey类型账号) + originalModel := reqModel + if account.Type == AccountTypeApiKey { + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { + // 替换请求体中的模型名 + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) + } + } + + // 获取凭证 + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + // 获取代理URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 重试循环 + var resp *http.Response + for attempt := 1; attempt <= maxRetries; attempt++ { + // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel) + if err != nil { + return nil, err + } + + // 发送请求 + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("upstream request failed: %w", err) + } + + // 检查是否需要重试 + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetries { + log.Printf("Account %d: upstream error %d, retry %d/%d after %v", + account.ID, resp.StatusCode, attempt, maxRetries, retryDelay) + _ = resp.Body.Close() + time.Sleep(retryDelay) + continue + } + // 最后一次尝试也失败,跳出循环处理重试耗尽 + break + } + + // 不需要重试(成功或不可重试的错误),跳出循环 + break + } + defer func() { _ = resp.Body.Close() }() + + // 处理重试耗尽的情况 + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + s.handleRetryExhaustedSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + // 处理可切换账号的错误 + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + + // 处理错误响应(不可重试的错误) + if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(resp.Body) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + if s.shouldFailoverOn400(respBody) { + if s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Account %d: 400 error, attempting failover: %s", + account.ID, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } else { + log.Printf("Account %d: 400 error, attempting failover", account.ID) + } + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + } + return s.handleErrorResponse(ctx, resp, c, account) + } + + // 处理正常响应 + var usage *ClaudeUsage + var firstTokenMs *int + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel) + if err != nil { + if err.Error() == "have error in stream" { + return nil, &UpstreamFailoverError{ + StatusCode: 403, + } + } + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) + if err != nil { + return nil, err + } + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { + // 确定目标URL + targetURL := claudeAPIURL + if account.Type == AccountTypeApiKey { + baseURL := account.GetBaseURL() + targetURL = baseURL + "/v1/messages" + } + + // OAuth账号:应用统一指纹 + var fingerprint *Fingerprint + if account.IsOAuth() && s.identityService != nil { + // 1. 获取或创建指纹(包含随机生成的ClientID) + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if err != nil { + log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) + // 失败时降级为透传原始headers + } else { + fingerprint = fp + + // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + body = newBody + } + } + } + } + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 设置认证头 + if tokenType == "oauth" { + req.Header.Set("authorization", "Bearer "+token) + } else { + req.Header.Set("x-api-key", token) + } + + // 白名单透传headers + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if allowedHeaders[lowerKey] { + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头) + if fingerprint != nil { + s.identityService.ApplyFingerprint(req, fingerprint) + } + + // 确保必要的headers存在 + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + // 处理anthropic-beta header(OAuth账号需要特殊处理) + if tokenType == "oauth" { + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } + } + + return req, nil +} + +// getBetaHeader 处理anthropic-beta header +// 对于OAuth账号,需要确保包含oauth-2025-04-20 +func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { + // 如果客户端传了anthropic-beta + if clientBetaHeader != "" { + // 已包含oauth beta则直接返回 + if strings.Contains(clientBetaHeader, claude.BetaOAuth) { + return clientBetaHeader + } + + // 需要添加oauth beta + parts := strings.Split(clientBetaHeader, ",") + for i, p := range parts { + parts[i] = strings.TrimSpace(p) + } + + // 在claude-code-20250219后面插入oauth beta + claudeCodeIdx := -1 + for i, p := range parts { + if p == claude.BetaClaudeCode { + claudeCodeIdx = i + break + } + } + + if claudeCodeIdx >= 0 { + // 在claude-code后面插入 + newParts := make([]string, 0, len(parts)+1) + newParts = append(newParts, parts[:claudeCodeIdx+1]...) + newParts = append(newParts, claude.BetaOAuth) + newParts = append(newParts, parts[claudeCodeIdx+1:]...) + return strings.Join(newParts, ",") + } + + // 没有claude-code,放在第一位 + return claude.BetaOAuth + "," + clientBetaHeader + } + + // 客户端没传,根据模型生成 + // haiku 模型不需要 claude-code beta + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.HaikuBetaHeader + } + + return claude.DefaultBetaHeader +} + +func requestNeedsBetaFeatures(body []byte) bool { + tools := gjson.GetBytes(body, "tools") + if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { + return true + } + if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { + return true + } + return false +} + +func defaultApiKeyBetaHeader(body []byte) string { + modelID := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(modelID), "haiku") { + return claude.ApiKeyHaikuBetaHeader + } + return claude.ApiKeyBetaHeader +} + +func truncateForLog(b []byte, maxBytes int) string { + if maxBytes <= 0 { + maxBytes = 2048 + } + if len(b) > maxBytes { + b = b[:maxBytes] + } + s := string(b) + // 保持一行,避免污染日志格式 + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\r") + return s +} + +func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { + // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 + // 默认保守:无法识别则不切换。 + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 + // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 + if strings.Contains(msg, "anthropic-beta") || + strings.Contains(msg, "beta feature") || + strings.Contains(msg, "requires beta") { + return true + } + + // thinking/tool streaming 等兼容性约束(常见于中间转换链路) + if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { + return true + } + if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { + return true + } + + return false +} + +func extractUpstreamErrorMessage(body []byte) string { + // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} + if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { + inner := strings.TrimSpace(m) + // 有些上游会把完整 JSON 作为字符串塞进 message + if strings.HasPrefix(inner, "{") { + if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + return m + } + + // 兜底:尝试顶层 message + return gjson.GetBytes(body, "message").String() +} + +func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + body, _ := io.ReadAll(resp.Body) + + // 处理上游错误,标记账号状态 + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + + // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) + var errType, errMsg string + var statusCode int + + switch resp.StatusCode { + case 400: + // 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开 + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "Upstream 400 error (account=%d platform=%s type=%s): %s", + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + c.Data(http.StatusBadRequest, "application/json", body) + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + case 401: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream authentication failed, please contact administrator" + case 403: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream access forbidden, please contact administrator" + case 429: + statusCode = http.StatusTooManyRequests + errType = "rate_limit_error" + errMsg = "Upstream rate limit exceeded, please retry later" + case 529: + statusCode = http.StatusServiceUnavailable + errType = "overloaded_error" + errMsg = "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream service temporarily unavailable" + default: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream request failed" + } + + // 返回自定义错误响应 + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) +} + +func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(resp.Body) + statusCode := resp.StatusCode + + // OAuth/Setup Token 账号的 403:标记账号异常 + if account.IsOAuth() && statusCode == 403 { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) + log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode) + } else { + // API Key 未配置错误码:不标记账号状态 + log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries) + } +} + +func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(resp.Body) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +// handleRetryExhaustedError 处理重试耗尽后的错误 +// OAuth 403:标记账号异常 +// API Key 未配置错误码:仅返回错误,不标记账号 +func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + s.handleRetryExhaustedSideEffects(ctx, resp, account) + + // 返回统一的重试耗尽错误响应 + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed after retries", + }, + }) + + return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) +} + +// streamingResult 流式响应结果 +type streamingResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) { + // 更新5h窗口状态 + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + + // 设置SSE响应头 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + // 透传其他响应头 + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + scanner := bufio.NewScanner(resp.Body) + // 设置更大的buffer以处理长行 + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + + needModelReplace := originalModel != mappedModel + + for scanner.Scan() { + line := scanner.Text() + if line == "event: error" { + return nil, errors.New("have error in stream") + } + + // Extract data from SSE line (supports both "data: " and "data:" formats) + if sseDataRe.MatchString(line) { + data := sseDataRe.ReplaceAllString(line, "") + + // 如果有模型映射,替换响应中的model字段 + if needModelReplace { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + // 转发行 + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + // 记录首字时间:第一个有效的 content_block_delta 或 message_start + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } else { + // 非 data 行直接转发 + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + } + } + + if err := scanner.Err(); err != nil { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + } + + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +// replaceModelInSSELine 替换SSE数据行中的model字段 +func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { + if !sseDataRe.MatchString(line) { + return line + } + data := sseDataRe.ReplaceAllString(line, "") + if data == "" || data == "[DONE]" { + return line + } + + var event map[string]any + if err := json.Unmarshal([]byte(data), &event); err != nil { + return line + } + + // 只替换 message_start 事件中的 message.model + if event["type"] != "message_start" { + return line + } + + msg, ok := event["message"].(map[string]any) + if !ok { + return line + } + + model, ok := msg["model"].(string) + if !ok || model != fromModel { + return line + } + + msg["model"] = toModel + newData, err := json.Marshal(event) + if err != nil { + return line + } + + return "data: " + string(newData) +} + +func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { + // 解析message_start获取input tokens(标准Claude API格式) + var msgStart struct { + Type string `json:"type"` + Message struct { + Usage ClaudeUsage `json:"usage"` + } `json:"message"` + } + if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { + usage.InputTokens = msgStart.Message.Usage.InputTokens + usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens + usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens + } + + // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) + var msgDelta struct { + Type string `json:"type"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` + } `json:"usage"` + } + if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { + // output_tokens 总是从 message_delta 获取 + usage.OutputTokens = msgDelta.Usage.OutputTokens + + // 如果 message_start 中没有值,则从 message_delta 获取(兼容GLM等API) + if usage.InputTokens == 0 { + usage.InputTokens = msgDelta.Usage.InputTokens + } + if usage.CacheCreationInputTokens == 0 { + usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens + } + if usage.CacheReadInputTokens == 0 { + usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens + } + } +} + +func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { + // 更新5h窗口状态 + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // 解析usage + var response struct { + Usage ClaudeUsage `json:"usage"` + } + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + // 如果有模型映射,替换响应中的model字段 + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + + // 透传响应头 + for key, values := range resp.Header { + for _, value := range values { + c.Header(key, value) + } + } + + // 写入响应 + c.Data(resp.StatusCode, "application/json", body) + + return &response.Usage, nil +} + +// replaceModelInResponseBody 替换响应体中的model字段 +func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + return body + } + + model, ok := resp["model"].(string) + if !ok || model != fromModel { + return body + } + + resp["model"] = toModel + newBody, err := json.Marshal(resp) + if err != nil { + return body + } + + return newBody +} + +// RecordUsageInput 记录使用量的输入参数 +type RecordUsageInput struct { + Result *ForwardResult + ApiKey *ApiKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 +} + +// RecordUsage 记录使用量并扣费(或更新订阅用量) +func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { + result := input.Result + apiKey := input.ApiKey + user := input.User + account := input.Account + subscription := input.Subscription + + // 计算费用 + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + } + + // 获取费率倍数 + multiplier := s.cfg.Default.RateMultiplier + if apiKey.GroupID != nil && apiKey.Group != nil { + multiplier = apiKey.Group.RateMultiplier + } + + cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) + if err != nil { + log.Printf("Calculate cost failed: %v", err) + // 使用默认费用继续 + cost = &CostBreakdown{ActualCost: 0} + } + + // 判断计费方式:订阅模式 vs 余额模式 + isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() + billingType := BillingTypeBalance + if isSubscriptionBilling { + billingType = BillingTypeSubscription + } + + // 创建使用日志 + durationMs := int(result.Duration.Milliseconds()) + usageLog := &UsageLog{ + UserID: user.ID, + ApiKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: result.RequestID, + Model: result.Model, + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + InputCost: cost.InputCost, + OutputCost: cost.OutputCost, + CacheCreationCost: cost.CacheCreationCost, + CacheReadCost: cost.CacheReadCost, + TotalCost: cost.TotalCost, + ActualCost: cost.ActualCost, + RateMultiplier: multiplier, + BillingType: billingType, + Stream: result.Stream, + DurationMs: &durationMs, + FirstTokenMs: result.FirstTokenMs, + CreatedAt: time.Now(), + } + + // 添加分组和订阅关联 + if apiKey.GroupID != nil { + usageLog.GroupID = apiKey.GroupID + } + if subscription != nil { + usageLog.SubscriptionID = &subscription.ID + } + + if err := s.usageLogRepo.Create(ctx, usageLog); err != nil { + log.Printf("Create usage log failed: %v", err) + } + + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + + // 根据计费类型执行扣费 + if isSubscriptionBilling { + // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) + if cost.TotalCost > 0 { + if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { + log.Printf("Increment subscription usage failed: %v", err) + } + // 异步更新订阅缓存 + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) + } + } else { + // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) + if cost.ActualCost > 0 { + if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { + log.Printf("Deduct balance failed: %v", err) + } + // 异步更新余额缓存 + s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + } + } + + // Schedule batch update for account last_used_at + s.deferredService.ScheduleLastUsedUpdate(account.ID) + + return nil +} + +// ForwardCountTokens 转发 count_tokens 请求到上游 API +// 特点:不记录使用量、仅支持非流式响应 +func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { + if parsed == nil { + s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return fmt.Errorf("parse request: empty request") + } + + body := parsed.Body + reqModel := parsed.Model + + // Antigravity 账户不支持 count_tokens 转发,直接返回空值 + if account.Platform == PlatformAntigravity { + c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) + return nil + } + + // 应用模型映射(仅对 apikey 类型账号) + if account.Type == AccountTypeApiKey { + if reqModel != "" { + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + } + } + } + + // 获取凭证 + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + + // 构建上游请求 + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + // 获取代理URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + // 读取响应体 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + + // 处理错误响应 + if resp.StatusCode >= 400 { + // 标记账号状态(429/529等) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + + // 记录上游错误摘要便于排障(不回显请求内容) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + log.Printf( + "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + // 返回简化的错误响应 + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + + // 透传成功响应 + c.Data(resp.StatusCode, "application/json", respBody) + return nil +} + +// buildCountTokensRequest 构建 count_tokens 上游请求 +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) { + // 确定目标 URL + targetURL := claudeAPICountTokensURL + if account.Type == AccountTypeApiKey { + baseURL := account.GetBaseURL() + targetURL = baseURL + "/v1/messages/count_tokens" + } + + // OAuth 账号:应用统一指纹和重写 userID + if account.IsOAuth() && s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if err == nil { + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + body = newBody + } + } + } + } + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 设置认证头 + if tokenType == "oauth" { + req.Header.Set("authorization", "Bearer "+token) + } else { + req.Header.Set("x-api-key", token) + } + + // 白名单透传 headers + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if allowedHeaders[lowerKey] { + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // OAuth 账号:应用指纹到请求头 + if account.IsOAuth() && s.identityService != nil { + fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if fp != nil { + s.identityService.ApplyFingerprint(req, fp) + } + } + + // 确保必要的 headers 存在 + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + // OAuth 账号:处理 anthropic-beta header + if tokenType == "oauth" { + req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta"))) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultApiKeyBetaHeader(body); beta != "" { + req.Header.Set("anthropic-beta", beta) + } + } + } + + return req, nil +} + +// countTokensError 返回 count_tokens 错误响应 +func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +// GetAvailableModels returns the list of models available for a group +// It aggregates model_mapping keys from all schedulable accounts in the group +func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + var accounts []Account + var err error + + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListSchedulable(ctx) + } + + if err != nil || len(accounts) == 0 { + return nil + } + + // Filter by platform if specified + if platform != "" { + filtered := make([]Account, 0) + for _, acc := range accounts { + if acc.Platform == platform { + filtered = append(filtered, acc) + } + } + accounts = filtered + } + + // Collect unique models from all accounts + modelSet := make(map[string]struct{}) + hasAnyMapping := false + + for _, acc := range accounts { + mapping := acc.GetModelMapping() + if len(mapping) > 0 { + hasAnyMapping = true + for model := range mapping { + modelSet[model] = struct{}{} + } + } + } + + // If no account has model_mapping, return nil (use default) + if !hasAnyMapping { + return nil + } + + // Convert to slice + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + + return models +} diff --git a/frontend/src/components/common/DateRangePicker.vue b/frontend/src/components/common/DateRangePicker.vue index 6c43c75d..4fce029f 100644 --- a/frontend/src/components/common/DateRangePicker.vue +++ b/frontend/src/components/common/DateRangePicker.vue @@ -1,444 +1,452 @@ - - - - - + + + + + diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index c52ddc4d..2f6aa2c8 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -1,851 +1,851 @@ -/** - * Core Type Definitions for TianShuAPI Frontend - */ - -// ==================== User & Auth Types ==================== - -export interface User { - id: number - username: string - notes: string - email: string - role: 'admin' | 'user' // User role for authorization - balance: number // User balance for API usage - concurrency: number // Allowed concurrent requests - status: 'active' | 'disabled' // Account status - allowed_groups: number[] | null // Allowed group IDs (null = all non-exclusive groups) - subscriptions?: UserSubscription[] // User's active subscriptions - created_at: string - updated_at: string -} - -export interface LoginRequest { - email: string - password: string - turnstile_token?: string -} - -export interface RegisterRequest { - email: string - password: string - verify_code?: string - turnstile_token?: string -} - -export interface SendVerifyCodeRequest { - email: string - turnstile_token?: string -} - -export interface SendVerifyCodeResponse { - message: string - countdown: number -} - -export interface PublicSettings { - registration_enabled: boolean - email_verify_enabled: boolean - turnstile_enabled: boolean - turnstile_site_key: string - site_name: string - site_logo: string - site_subtitle: string - api_base_url: string - contact_info: string - doc_url: string - version: string -} - -export interface AuthResponse { - access_token: string - token_type: string - user: User & { run_mode?: 'standard' | 'simple' } -} - -export interface CurrentUserResponse extends User { - run_mode?: 'standard' | 'simple' -} - -// ==================== Subscription Types ==================== - -export interface Subscription { - id: number - user_id: number - name: string - url: string - type: 'clash' | 'v2ray' | 'surge' | 'quantumult' | 'shadowrocket' - update_interval: number // in hours - last_updated: string | null - node_count: number - is_active: boolean - created_at: string - updated_at: string -} - -export interface CreateSubscriptionRequest { - name: string - url: string - type: Subscription['type'] - update_interval?: number -} - -export interface UpdateSubscriptionRequest { - name?: string - url?: string - type?: Subscription['type'] - update_interval?: number - is_active?: boolean -} - -// ==================== Proxy Node Types ==================== - -export interface ProxyNode { - id: number - subscription_id: number - name: string - type: 'ss' | 'ssr' | 'vmess' | 'vless' | 'trojan' | 'hysteria' | 'hysteria2' - server: string - port: number - config: Record // JSON configuration specific to proxy type - latency: number | null // in milliseconds - last_checked: string | null - is_available: boolean - created_at: string - updated_at: string -} - -// ==================== Conversion Types ==================== - -export interface ConversionRequest { - subscription_ids: number[] - target_type: 'clash' | 'v2ray' | 'surge' | 'quantumult' | 'shadowrocket' - filter?: { - name_pattern?: string - types?: ProxyNode['type'][] - min_latency?: number - max_latency?: number - available_only?: boolean - } - sort?: { - by: 'name' | 'latency' | 'type' - order: 'asc' | 'desc' - } -} - -export interface ConversionResult { - url: string // URL to download the converted subscription - expires_at: string - node_count: number -} - -// ==================== Statistics Types ==================== - -export interface SubscriptionStats { - subscription_id: number - total_nodes: number - available_nodes: number - avg_latency: number | null - by_type: Record - last_update: string -} - -export interface UserStats { - total_subscriptions: number - total_nodes: number - active_subscriptions: number - total_conversions: number - last_conversion: string | null -} - -// ==================== API Response Types ==================== - -export interface ApiResponse { - code: number - message: string - data: T -} - -export interface ApiError { - detail: string - code?: string - field?: string -} - -export interface PaginatedResponse { - items: T[] - total: number - page: number - page_size: number - pages: number -} - -// ==================== UI State Types ==================== - -export type ToastType = 'success' | 'error' | 'info' | 'warning' - -export interface Toast { - id: string - type: ToastType - message: string - title?: string - duration?: number // in milliseconds, undefined means no auto-dismiss - startTime?: number // timestamp when toast was created, for progress bar -} - -export interface AppState { - sidebarCollapsed: boolean - loading: boolean - toasts: Toast[] -} - -// ==================== Validation Types ==================== - -export interface ValidationError { - field: string - message: string -} - -// ==================== Table/List Types ==================== - -export interface SortConfig { - key: string - order: 'asc' | 'desc' -} - -export interface FilterConfig { - [key: string]: string | number | boolean | null | undefined -} - -export interface PaginationConfig { - page: number - page_size: number -} - -// ==================== API Key & Group Types ==================== - -export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' - -export type SubscriptionType = 'standard' | 'subscription' - -export interface Group { - id: number - name: string - description: string | null - platform: GroupPlatform - rate_multiplier: number - is_exclusive: boolean - status: 'active' | 'inactive' - subscription_type: SubscriptionType - daily_limit_usd: number | null - weekly_limit_usd: number | null - monthly_limit_usd: number | null - account_count?: number - created_at: string - updated_at: string -} - -export interface ApiKey { - id: number - user_id: number - key: string - name: string - group_id: number | null - status: 'active' | 'inactive' - created_at: string - updated_at: string - group?: Group -} - -export interface CreateApiKeyRequest { - name: string - group_id?: number | null - custom_key?: string // Optional custom API Key -} - -export interface UpdateApiKeyRequest { - name?: string - group_id?: number | null - status?: 'active' | 'inactive' -} - -export interface CreateGroupRequest { - name: string - description?: string | null - platform?: GroupPlatform - rate_multiplier?: number - is_exclusive?: boolean -} - -export interface UpdateGroupRequest { - name?: string - description?: string | null - platform?: GroupPlatform - rate_multiplier?: number - is_exclusive?: boolean - status?: 'active' | 'inactive' -} - -// ==================== Account & Proxy Types ==================== - -export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' -export type AccountType = 'oauth' | 'setup-token' | 'apikey' -export type OAuthAddMethod = 'oauth' | 'setup-token' -export type ProxyProtocol = 'http' | 'https' | 'socks5' - -// Claude Model type (returned by /v1/models and account models API) -export interface ClaudeModel { - id: string - type: string - display_name: string - created_at: string -} - -export interface Proxy { - id: number - name: string - protocol: ProxyProtocol - host: string - port: number - username: string | null - password?: string | null - status: 'active' | 'inactive' - account_count?: number // Number of accounts using this proxy - created_at: string - updated_at: string -} - -// Gemini credentials structure for OAuth and API Key authentication -export interface GeminiCredentials { - // API Key authentication - api_key?: string - - // OAuth authentication - access_token?: string - refresh_token?: string - oauth_type?: 'code_assist' | 'ai_studio' | string - tier_id?: 'LEGACY' | 'PRO' | 'ULTRA' | string - project_id?: string - token_type?: string - scope?: string - expires_at?: string -} - -export interface Account { - id: number - name: string - platform: AccountPlatform - type: AccountType - credentials?: Record - extra?: CodexUsageSnapshot & Record // Extra fields including Codex usage - proxy_id: number | null - concurrency: number - current_concurrency?: number // Real-time concurrency count from Redis - priority: number - status: 'active' | 'inactive' | 'error' - error_message: string | null - last_used_at: string | null - created_at: string - updated_at: string - proxy?: Proxy - group_ids?: number[] // Groups this account belongs to - groups?: Group[] // Preloaded group objects - - // Rate limit & scheduling fields - schedulable: boolean - rate_limited_at: string | null - rate_limit_reset_at: string | null - overload_until: string | null - - // Session window fields (5-hour window) - session_window_start: string | null - session_window_end: string | null - session_window_status: 'allowed' | 'allowed_warning' | 'rejected' | null -} - -// Account Usage types -export interface WindowStats { - requests: number - tokens: number - cost: number -} - -export interface UsageProgress { - utilization: number // Percentage (0-100+, 100 = 100%) - resets_at: string | null - remaining_seconds: number - window_stats?: WindowStats | null // 窗口期统计(从窗口开始到当前的使用量) -} - -// Antigravity 单个模型的配额信息 -export interface AntigravityModelQuota { - utilization: number // 使用率 0-100 - reset_time: string // 重置时间 ISO8601 -} - -export interface AccountUsageInfo { - updated_at: string | null - five_hour: UsageProgress | null - seven_day: UsageProgress | null - seven_day_sonnet: UsageProgress | null - gemini_pro_daily?: UsageProgress | null - gemini_flash_daily?: UsageProgress | null - antigravity_quota?: Record | null -} - -// OpenAI Codex usage snapshot (from response headers) -export interface CodexUsageSnapshot { - // Legacy fields (kept for backwards compatibility) - // NOTE: The naming is ambiguous - actual window type is determined by window_minutes value - codex_primary_used_percent?: number // Usage percentage (check window_minutes for actual window type) - codex_primary_reset_after_seconds?: number // Seconds until reset - codex_primary_window_minutes?: number // Window in minutes - codex_secondary_used_percent?: number // Usage percentage (check window_minutes for actual window type) - codex_secondary_reset_after_seconds?: number // Seconds until reset - codex_secondary_window_minutes?: number // Window in minutes - codex_primary_over_secondary_percent?: number // Overflow ratio - - // Canonical fields (normalized by backend, use these preferentially) - codex_5h_used_percent?: number // 5-hour window usage percentage - codex_5h_reset_after_seconds?: number // Seconds until 5h window reset - codex_5h_window_minutes?: number // 5h window in minutes (should be ~300) - codex_7d_used_percent?: number // 7-day window usage percentage - codex_7d_reset_after_seconds?: number // Seconds until 7d window reset - codex_7d_window_minutes?: number // 7d window in minutes (should be ~10080) - - codex_usage_updated_at?: string // Last update timestamp -} - -export interface CreateAccountRequest { - name: string - platform: AccountPlatform - type: AccountType - credentials: Record - extra?: Record - proxy_id?: number | null - concurrency?: number - priority?: number - group_ids?: number[] -} - -export interface UpdateAccountRequest { - name?: string - type?: AccountType - credentials?: Record - extra?: Record - proxy_id?: number | null - concurrency?: number - priority?: number - status?: 'active' | 'inactive' - group_ids?: number[] -} - -export interface CreateProxyRequest { - name: string - protocol: ProxyProtocol - host: string - port: number - username?: string | null - password?: string | null -} - -export interface UpdateProxyRequest { - name?: string - protocol?: ProxyProtocol - host?: string - port?: number - username?: string | null - password?: string | null - status?: 'active' | 'inactive' -} - -// ==================== Usage & Redeem Types ==================== - -export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' - -// 消费类型: 0=钱包余额, 1=订阅套餐 -export type BillingType = 0 | 1 - -export interface UsageLog { - id: number - user_id: number - api_key_id: number - account_id: number | null - request_id: string - model: string - - group_id: number | null - subscription_id: number | null - - input_tokens: number - output_tokens: number - cache_creation_tokens: number - cache_read_tokens: number - cache_creation_5m_tokens: number - cache_creation_1h_tokens: number - - input_cost: number - output_cost: number - cache_creation_cost: number - cache_read_cost: number - total_cost: number - actual_cost: number - rate_multiplier: number - - billing_type: BillingType - stream: boolean - duration_ms: number - first_token_ms: number | null - created_at: string - - user?: User - api_key?: ApiKey - account?: Account - group?: Group - subscription?: UserSubscription -} - -export interface RedeemCode { - id: number - code: string - type: RedeemCodeType - value: number - status: 'active' | 'used' | 'expired' | 'unused' - used_by: number | null - used_at: string | null - created_at: string - updated_at?: string - group_id?: number | null // 订阅类型专用 - validity_days?: number // 订阅类型专用 - user?: User - group?: Group // 关联的分组 -} - -export interface GenerateRedeemCodesRequest { - count: number - type: RedeemCodeType - value: number - group_id?: number | null // 订阅类型专用 - validity_days?: number // 订阅类型专用 -} - -export interface RedeemCodeRequest { - code: string -} - -// ==================== Dashboard & Statistics ==================== - -export interface DashboardStats { - // 用户统计 - total_users: number - today_new_users: number // 今日新增用户数 - active_users: number // 今日有请求的用户数 - - // API Key 统计 - total_api_keys: number - active_api_keys: number // 状态为 active 的 API Key 数 - - // 账户统计 - total_accounts: number - normal_accounts: number // 正常账户数 - error_accounts: number // 异常账户数 - ratelimit_accounts: number // 限流账户数 - overload_accounts: number // 过载账户数 - - // 累计 Token 使用统计 - total_requests: number - total_input_tokens: number - total_output_tokens: number - total_cache_creation_tokens: number - total_cache_read_tokens: number - total_tokens: number - total_cost: number // 累计标准计费 - total_actual_cost: number // 累计实际扣除 - - // 今日 Token 使用统计 - today_requests: number - today_input_tokens: number - today_output_tokens: number - today_cache_creation_tokens: number - today_cache_read_tokens: number - today_tokens: number - today_cost: number // 今日标准计费 - today_actual_cost: number // 今日实际扣除 - - // 系统运行统计 - average_duration_ms: number // 平均响应时间 - uptime: number // 系统运行时间(秒) - - // 性能指标 - rpm: number // 近5分钟平均每分钟请求数 - tpm: number // 近5分钟平均每分钟Token数 -} - -export interface UsageStatsResponse { - period?: string - total_requests: number - total_input_tokens: number - total_output_tokens: number - total_cache_tokens: number - total_tokens: number - total_cost: number // 标准计费 - total_actual_cost: number // 实际扣除 - average_duration_ms: number - models?: Record -} - -// ==================== Trend & Chart Types ==================== - -export interface TrendDataPoint { - date: string - requests: number - input_tokens: number - output_tokens: number - cache_tokens: number - total_tokens: number - cost: number // 标准计费 - actual_cost: number // 实际扣除 -} - -export interface ModelStat { - model: string - requests: number - input_tokens: number - output_tokens: number - total_tokens: number - cost: number // 标准计费 - actual_cost: number // 实际扣除 -} - -export interface UserUsageTrendPoint { - date: string - user_id: number - email: string - requests: number - tokens: number - cost: number // 标准计费 - actual_cost: number // 实际扣除 -} - -export interface ApiKeyUsageTrendPoint { - date: string - api_key_id: number - key_name: string - requests: number - tokens: number -} - -// ==================== Admin User Management ==================== - -export interface UpdateUserRequest { - email?: string - password?: string - username?: string - notes?: string - role?: 'admin' | 'user' - balance?: number - concurrency?: number - status?: 'active' | 'disabled' - allowed_groups?: number[] | null -} - -export interface ChangePasswordRequest { - old_password: string - new_password: string -} - -// ==================== User Subscription Types ==================== - -export interface UserSubscription { - id: number - user_id: number - group_id: number - status: 'active' | 'expired' | 'revoked' - daily_usage_usd: number - weekly_usage_usd: number - monthly_usage_usd: number - daily_window_start: string | null - weekly_window_start: string | null - monthly_window_start: string | null - created_at: string - updated_at: string - expires_at: string | null - user?: User - group?: Group -} - -export interface SubscriptionProgress { - subscription_id: number - daily: { - used: number - limit: number | null - percentage: number - reset_in_seconds: number | null - } | null - weekly: { - used: number - limit: number | null - percentage: number - reset_in_seconds: number | null - } | null - monthly: { - used: number - limit: number | null - percentage: number - reset_in_seconds: number | null - } | null - expires_at: string | null - days_remaining: number | null -} - -export interface AssignSubscriptionRequest { - user_id: number - group_id: number - validity_days?: number -} - -export interface BulkAssignSubscriptionRequest { - user_ids: number[] - group_id: number - validity_days?: number -} - -export interface ExtendSubscriptionRequest { - days: number -} - -// ==================== Query Parameters ==================== - -export interface UsageQueryParams { - page?: number - page_size?: number - api_key_id?: number - user_id?: number - account_id?: number - group_id?: number - model?: string - stream?: boolean - billing_type?: number - start_date?: string - end_date?: string -} - -// ==================== Account Usage Statistics ==================== - -export interface AccountUsageHistory { - date: string - label: string - requests: number - tokens: number - cost: number - actual_cost: number -} - -export interface AccountUsageSummary { - days: number - actual_days_used: number - total_cost: number - total_standard_cost: number - total_requests: number - total_tokens: number - avg_daily_cost: number - avg_daily_requests: number - avg_daily_tokens: number - avg_duration_ms: number - today: { - date: string - cost: number - requests: number - tokens: number - } | null - highest_cost_day: { - date: string - label: string - cost: number - requests: number - } | null - highest_request_day: { - date: string - label: string - requests: number - cost: number - } | null -} - -export interface AccountUsageStatsResponse { - history: AccountUsageHistory[] - summary: AccountUsageSummary - models: ModelStat[] -} - -// ==================== User Attribute Types ==================== - -export type UserAttributeType = 'text' | 'textarea' | 'number' | 'email' | 'url' | 'date' | 'select' | 'multi_select' - -export interface UserAttributeOption { - value: string - label: string -} - -export interface UserAttributeValidation { - min_length?: number - max_length?: number - min?: number - max?: number - pattern?: string - message?: string -} - -export interface UserAttributeDefinition { - id: number - key: string - name: string - description: string - type: UserAttributeType - options: UserAttributeOption[] - required: boolean - validation: UserAttributeValidation - placeholder: string - display_order: number - enabled: boolean - created_at: string - updated_at: string -} - -export interface UserAttributeValue { - id: number - user_id: number - attribute_id: number - value: string - created_at: string - updated_at: string -} - -export interface CreateUserAttributeRequest { - key: string - name: string - description?: string - type: UserAttributeType - options?: UserAttributeOption[] - required?: boolean - validation?: UserAttributeValidation - placeholder?: string - display_order?: number - enabled?: boolean -} - -export interface UpdateUserAttributeRequest { - key?: string - name?: string - description?: string - type?: UserAttributeType - options?: UserAttributeOption[] - required?: boolean - validation?: UserAttributeValidation - placeholder?: string - display_order?: number - enabled?: boolean -} - -export interface UserAttributeValuesMap { - [attributeId: number]: string -} +/** + * Core Type Definitions for Sub2API Frontend + */ + +// ==================== User & Auth Types ==================== + +export interface User { + id: number + username: string + notes: string + email: string + role: 'admin' | 'user' // User role for authorization + balance: number // User balance for API usage + concurrency: number // Allowed concurrent requests + status: 'active' | 'disabled' // Account status + allowed_groups: number[] | null // Allowed group IDs (null = all non-exclusive groups) + subscriptions?: UserSubscription[] // User's active subscriptions + created_at: string + updated_at: string +} + +export interface LoginRequest { + email: string + password: string + turnstile_token?: string +} + +export interface RegisterRequest { + email: string + password: string + verify_code?: string + turnstile_token?: string +} + +export interface SendVerifyCodeRequest { + email: string + turnstile_token?: string +} + +export interface SendVerifyCodeResponse { + message: string + countdown: number +} + +export interface PublicSettings { + registration_enabled: boolean + email_verify_enabled: boolean + turnstile_enabled: boolean + turnstile_site_key: string + site_name: string + site_logo: string + site_subtitle: string + api_base_url: string + contact_info: string + doc_url: string + version: string +} + +export interface AuthResponse { + access_token: string + token_type: string + user: User & { run_mode?: 'standard' | 'simple' } +} + +export interface CurrentUserResponse extends User { + run_mode?: 'standard' | 'simple' +} + +// ==================== Subscription Types ==================== + +export interface Subscription { + id: number + user_id: number + name: string + url: string + type: 'clash' | 'v2ray' | 'surge' | 'quantumult' | 'shadowrocket' + update_interval: number // in hours + last_updated: string | null + node_count: number + is_active: boolean + created_at: string + updated_at: string +} + +export interface CreateSubscriptionRequest { + name: string + url: string + type: Subscription['type'] + update_interval?: number +} + +export interface UpdateSubscriptionRequest { + name?: string + url?: string + type?: Subscription['type'] + update_interval?: number + is_active?: boolean +} + +// ==================== Proxy Node Types ==================== + +export interface ProxyNode { + id: number + subscription_id: number + name: string + type: 'ss' | 'ssr' | 'vmess' | 'vless' | 'trojan' | 'hysteria' | 'hysteria2' + server: string + port: number + config: Record // JSON configuration specific to proxy type + latency: number | null // in milliseconds + last_checked: string | null + is_available: boolean + created_at: string + updated_at: string +} + +// ==================== Conversion Types ==================== + +export interface ConversionRequest { + subscription_ids: number[] + target_type: 'clash' | 'v2ray' | 'surge' | 'quantumult' | 'shadowrocket' + filter?: { + name_pattern?: string + types?: ProxyNode['type'][] + min_latency?: number + max_latency?: number + available_only?: boolean + } + sort?: { + by: 'name' | 'latency' | 'type' + order: 'asc' | 'desc' + } +} + +export interface ConversionResult { + url: string // URL to download the converted subscription + expires_at: string + node_count: number +} + +// ==================== Statistics Types ==================== + +export interface SubscriptionStats { + subscription_id: number + total_nodes: number + available_nodes: number + avg_latency: number | null + by_type: Record + last_update: string +} + +export interface UserStats { + total_subscriptions: number + total_nodes: number + active_subscriptions: number + total_conversions: number + last_conversion: string | null +} + +// ==================== API Response Types ==================== + +export interface ApiResponse { + code: number + message: string + data: T +} + +export interface ApiError { + detail: string + code?: string + field?: string +} + +export interface PaginatedResponse { + items: T[] + total: number + page: number + page_size: number + pages: number +} + +// ==================== UI State Types ==================== + +export type ToastType = 'success' | 'error' | 'info' | 'warning' + +export interface Toast { + id: string + type: ToastType + message: string + title?: string + duration?: number // in milliseconds, undefined means no auto-dismiss + startTime?: number // timestamp when toast was created, for progress bar +} + +export interface AppState { + sidebarCollapsed: boolean + loading: boolean + toasts: Toast[] +} + +// ==================== Validation Types ==================== + +export interface ValidationError { + field: string + message: string +} + +// ==================== Table/List Types ==================== + +export interface SortConfig { + key: string + order: 'asc' | 'desc' +} + +export interface FilterConfig { + [key: string]: string | number | boolean | null | undefined +} + +export interface PaginationConfig { + page: number + page_size: number +} + +// ==================== API Key & Group Types ==================== + +export type GroupPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' + +export type SubscriptionType = 'standard' | 'subscription' + +export interface Group { + id: number + name: string + description: string | null + platform: GroupPlatform + rate_multiplier: number + is_exclusive: boolean + status: 'active' | 'inactive' + subscription_type: SubscriptionType + daily_limit_usd: number | null + weekly_limit_usd: number | null + monthly_limit_usd: number | null + account_count?: number + created_at: string + updated_at: string +} + +export interface ApiKey { + id: number + user_id: number + key: string + name: string + group_id: number | null + status: 'active' | 'inactive' + created_at: string + updated_at: string + group?: Group +} + +export interface CreateApiKeyRequest { + name: string + group_id?: number | null + custom_key?: string // Optional custom API Key +} + +export interface UpdateApiKeyRequest { + name?: string + group_id?: number | null + status?: 'active' | 'inactive' +} + +export interface CreateGroupRequest { + name: string + description?: string | null + platform?: GroupPlatform + rate_multiplier?: number + is_exclusive?: boolean +} + +export interface UpdateGroupRequest { + name?: string + description?: string | null + platform?: GroupPlatform + rate_multiplier?: number + is_exclusive?: boolean + status?: 'active' | 'inactive' +} + +// ==================== Account & Proxy Types ==================== + +export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' +export type AccountType = 'oauth' | 'setup-token' | 'apikey' +export type OAuthAddMethod = 'oauth' | 'setup-token' +export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' + +// Claude Model type (returned by /v1/models and account models API) +export interface ClaudeModel { + id: string + type: string + display_name: string + created_at: string +} + +export interface Proxy { + id: number + name: string + protocol: ProxyProtocol + host: string + port: number + username: string | null + password?: string | null + status: 'active' | 'inactive' + account_count?: number // Number of accounts using this proxy + created_at: string + updated_at: string +} + +// Gemini credentials structure for OAuth and API Key authentication +export interface GeminiCredentials { + // API Key authentication + api_key?: string + + // OAuth authentication + access_token?: string + refresh_token?: string + oauth_type?: 'code_assist' | 'ai_studio' | string + tier_id?: 'LEGACY' | 'PRO' | 'ULTRA' | string + project_id?: string + token_type?: string + scope?: string + expires_at?: string +} + +export interface Account { + id: number + name: string + platform: AccountPlatform + type: AccountType + credentials?: Record + extra?: CodexUsageSnapshot & Record // Extra fields including Codex usage + proxy_id: number | null + concurrency: number + current_concurrency?: number // Real-time concurrency count from Redis + priority: number + status: 'active' | 'inactive' | 'error' + error_message: string | null + last_used_at: string | null + created_at: string + updated_at: string + proxy?: Proxy + group_ids?: number[] // Groups this account belongs to + groups?: Group[] // Preloaded group objects + + // Rate limit & scheduling fields + schedulable: boolean + rate_limited_at: string | null + rate_limit_reset_at: string | null + overload_until: string | null + + // Session window fields (5-hour window) + session_window_start: string | null + session_window_end: string | null + session_window_status: 'allowed' | 'allowed_warning' | 'rejected' | null +} + +// Account Usage types +export interface WindowStats { + requests: number + tokens: number + cost: number +} + +export interface UsageProgress { + utilization: number // Percentage (0-100+, 100 = 100%) + resets_at: string | null + remaining_seconds: number + window_stats?: WindowStats | null // 窗口期统计(从窗口开始到当前的使用量) +} + +// Antigravity 单个模型的配额信息 +export interface AntigravityModelQuota { + utilization: number // 使用率 0-100 + reset_time: string // 重置时间 ISO8601 +} + +export interface AccountUsageInfo { + updated_at: string | null + five_hour: UsageProgress | null + seven_day: UsageProgress | null + seven_day_sonnet: UsageProgress | null + gemini_pro_daily?: UsageProgress | null + gemini_flash_daily?: UsageProgress | null + antigravity_quota?: Record | null +} + +// OpenAI Codex usage snapshot (from response headers) +export interface CodexUsageSnapshot { + // Legacy fields (kept for backwards compatibility) + // NOTE: The naming is ambiguous - actual window type is determined by window_minutes value + codex_primary_used_percent?: number // Usage percentage (check window_minutes for actual window type) + codex_primary_reset_after_seconds?: number // Seconds until reset + codex_primary_window_minutes?: number // Window in minutes + codex_secondary_used_percent?: number // Usage percentage (check window_minutes for actual window type) + codex_secondary_reset_after_seconds?: number // Seconds until reset + codex_secondary_window_minutes?: number // Window in minutes + codex_primary_over_secondary_percent?: number // Overflow ratio + + // Canonical fields (normalized by backend, use these preferentially) + codex_5h_used_percent?: number // 5-hour window usage percentage + codex_5h_reset_after_seconds?: number // Seconds until 5h window reset + codex_5h_window_minutes?: number // 5h window in minutes (should be ~300) + codex_7d_used_percent?: number // 7-day window usage percentage + codex_7d_reset_after_seconds?: number // Seconds until 7d window reset + codex_7d_window_minutes?: number // 7d window in minutes (should be ~10080) + + codex_usage_updated_at?: string // Last update timestamp +} + +export interface CreateAccountRequest { + name: string + platform: AccountPlatform + type: AccountType + credentials: Record + extra?: Record + proxy_id?: number | null + concurrency?: number + priority?: number + group_ids?: number[] +} + +export interface UpdateAccountRequest { + name?: string + type?: AccountType + credentials?: Record + extra?: Record + proxy_id?: number | null + concurrency?: number + priority?: number + status?: 'active' | 'inactive' + group_ids?: number[] +} + +export interface CreateProxyRequest { + name: string + protocol: ProxyProtocol + host: string + port: number + username?: string | null + password?: string | null +} + +export interface UpdateProxyRequest { + name?: string + protocol?: ProxyProtocol + host?: string + port?: number + username?: string | null + password?: string | null + status?: 'active' | 'inactive' +} + +// ==================== Usage & Redeem Types ==================== + +export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' + +// 消费类型: 0=钱包余额, 1=订阅套餐 +export type BillingType = 0 | 1 + +export interface UsageLog { + id: number + user_id: number + api_key_id: number + account_id: number | null + request_id: string + model: string + + group_id: number | null + subscription_id: number | null + + input_tokens: number + output_tokens: number + cache_creation_tokens: number + cache_read_tokens: number + cache_creation_5m_tokens: number + cache_creation_1h_tokens: number + + input_cost: number + output_cost: number + cache_creation_cost: number + cache_read_cost: number + total_cost: number + actual_cost: number + rate_multiplier: number + + billing_type: BillingType + stream: boolean + duration_ms: number + first_token_ms: number | null + created_at: string + + user?: User + api_key?: ApiKey + account?: Account + group?: Group + subscription?: UserSubscription +} + +export interface RedeemCode { + id: number + code: string + type: RedeemCodeType + value: number + status: 'active' | 'used' | 'expired' | 'unused' + used_by: number | null + used_at: string | null + created_at: string + updated_at?: string + group_id?: number | null // 订阅类型专用 + validity_days?: number // 订阅类型专用 + user?: User + group?: Group // 关联的分组 +} + +export interface GenerateRedeemCodesRequest { + count: number + type: RedeemCodeType + value: number + group_id?: number | null // 订阅类型专用 + validity_days?: number // 订阅类型专用 +} + +export interface RedeemCodeRequest { + code: string +} + +// ==================== Dashboard & Statistics ==================== + +export interface DashboardStats { + // 用户统计 + total_users: number + today_new_users: number // 今日新增用户数 + active_users: number // 今日有请求的用户数 + + // API Key 统计 + total_api_keys: number + active_api_keys: number // 状态为 active 的 API Key 数 + + // 账户统计 + total_accounts: number + normal_accounts: number // 正常账户数 + error_accounts: number // 异常账户数 + ratelimit_accounts: number // 限流账户数 + overload_accounts: number // 过载账户数 + + // 累计 Token 使用统计 + total_requests: number + total_input_tokens: number + total_output_tokens: number + total_cache_creation_tokens: number + total_cache_read_tokens: number + total_tokens: number + total_cost: number // 累计标准计费 + total_actual_cost: number // 累计实际扣除 + + // 今日 Token 使用统计 + today_requests: number + today_input_tokens: number + today_output_tokens: number + today_cache_creation_tokens: number + today_cache_read_tokens: number + today_tokens: number + today_cost: number // 今日标准计费 + today_actual_cost: number // 今日实际扣除 + + // 系统运行统计 + average_duration_ms: number // 平均响应时间 + uptime: number // 系统运行时间(秒) + + // 性能指标 + rpm: number // 近5分钟平均每分钟请求数 + tpm: number // 近5分钟平均每分钟Token数 +} + +export interface UsageStatsResponse { + period?: string + total_requests: number + total_input_tokens: number + total_output_tokens: number + total_cache_tokens: number + total_tokens: number + total_cost: number // 标准计费 + total_actual_cost: number // 实际扣除 + average_duration_ms: number + models?: Record +} + +// ==================== Trend & Chart Types ==================== + +export interface TrendDataPoint { + date: string + requests: number + input_tokens: number + output_tokens: number + cache_tokens: number + total_tokens: number + cost: number // 标准计费 + actual_cost: number // 实际扣除 +} + +export interface ModelStat { + model: string + requests: number + input_tokens: number + output_tokens: number + total_tokens: number + cost: number // 标准计费 + actual_cost: number // 实际扣除 +} + +export interface UserUsageTrendPoint { + date: string + user_id: number + email: string + requests: number + tokens: number + cost: number // 标准计费 + actual_cost: number // 实际扣除 +} + +export interface ApiKeyUsageTrendPoint { + date: string + api_key_id: number + key_name: string + requests: number + tokens: number +} + +// ==================== Admin User Management ==================== + +export interface UpdateUserRequest { + email?: string + password?: string + username?: string + notes?: string + role?: 'admin' | 'user' + balance?: number + concurrency?: number + status?: 'active' | 'disabled' + allowed_groups?: number[] | null +} + +export interface ChangePasswordRequest { + old_password: string + new_password: string +} + +// ==================== User Subscription Types ==================== + +export interface UserSubscription { + id: number + user_id: number + group_id: number + status: 'active' | 'expired' | 'revoked' + daily_usage_usd: number + weekly_usage_usd: number + monthly_usage_usd: number + daily_window_start: string | null + weekly_window_start: string | null + monthly_window_start: string | null + created_at: string + updated_at: string + expires_at: string | null + user?: User + group?: Group +} + +export interface SubscriptionProgress { + subscription_id: number + daily: { + used: number + limit: number | null + percentage: number + reset_in_seconds: number | null + } | null + weekly: { + used: number + limit: number | null + percentage: number + reset_in_seconds: number | null + } | null + monthly: { + used: number + limit: number | null + percentage: number + reset_in_seconds: number | null + } | null + expires_at: string | null + days_remaining: number | null +} + +export interface AssignSubscriptionRequest { + user_id: number + group_id: number + validity_days?: number +} + +export interface BulkAssignSubscriptionRequest { + user_ids: number[] + group_id: number + validity_days?: number +} + +export interface ExtendSubscriptionRequest { + days: number +} + +// ==================== Query Parameters ==================== + +export interface UsageQueryParams { + page?: number + page_size?: number + api_key_id?: number + user_id?: number + account_id?: number + group_id?: number + model?: string + stream?: boolean + billing_type?: number + start_date?: string + end_date?: string +} + +// ==================== Account Usage Statistics ==================== + +export interface AccountUsageHistory { + date: string + label: string + requests: number + tokens: number + cost: number + actual_cost: number +} + +export interface AccountUsageSummary { + days: number + actual_days_used: number + total_cost: number + total_standard_cost: number + total_requests: number + total_tokens: number + avg_daily_cost: number + avg_daily_requests: number + avg_daily_tokens: number + avg_duration_ms: number + today: { + date: string + cost: number + requests: number + tokens: number + } | null + highest_cost_day: { + date: string + label: string + cost: number + requests: number + } | null + highest_request_day: { + date: string + label: string + requests: number + cost: number + } | null +} + +export interface AccountUsageStatsResponse { + history: AccountUsageHistory[] + summary: AccountUsageSummary + models: ModelStat[] +} + +// ==================== User Attribute Types ==================== + +export type UserAttributeType = 'text' | 'textarea' | 'number' | 'email' | 'url' | 'date' | 'select' | 'multi_select' + +export interface UserAttributeOption { + value: string + label: string +} + +export interface UserAttributeValidation { + min_length?: number + max_length?: number + min?: number + max?: number + pattern?: string + message?: string +} + +export interface UserAttributeDefinition { + id: number + key: string + name: string + description: string + type: UserAttributeType + options: UserAttributeOption[] + required: boolean + validation: UserAttributeValidation + placeholder: string + display_order: number + enabled: boolean + created_at: string + updated_at: string +} + +export interface UserAttributeValue { + id: number + user_id: number + attribute_id: number + value: string + created_at: string + updated_at: string +} + +export interface CreateUserAttributeRequest { + key: string + name: string + description?: string + type: UserAttributeType + options?: UserAttributeOption[] + required?: boolean + validation?: UserAttributeValidation + placeholder?: string + display_order?: number + enabled?: boolean +} + +export interface UpdateUserAttributeRequest { + key?: string + name?: string + description?: string + type?: UserAttributeType + options?: UserAttributeOption[] + required?: boolean + validation?: UserAttributeValidation + placeholder?: string + display_order?: number + enabled?: boolean +} + +export interface UserAttributeValuesMap { + [attributeId: number]: string +} diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index dc7770ee..613b503c 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -1,995 +1,997 @@ - - - + + + diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index 70f03b40..ac5d1e05 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -1,1430 +1,1436 @@ - - - + + +