diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index 8cfc8222..c51317a4 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -120,10 +120,9 @@ func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() { func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { started := make(chan struct{}) - block := make(chan struct{}) s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { close(started) - <-block + <-r.Context().Done() })) ctx, cancel := context.WithCancel(s.ctx) @@ -136,7 +135,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { <-started cancel() - close(block) err := <-done require.Error(s.T(), err) diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index e1f9d252..6a0241fb 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "time" @@ -20,6 +21,7 @@ type CRSSyncService struct { proxyRepo ProxyRepository oauthService *OAuthService openaiOAuthService *OpenAIOAuthService + geminiOAuthService *GeminiOAuthService } func NewCRSSyncService( @@ -27,12 +29,14 @@ func NewCRSSyncService( proxyRepo ProxyRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, + geminiOAuthService *GeminiOAuthService, ) *CRSSyncService { return &CRSSyncService{ accountRepo: accountRepo, proxyRepo: proxyRepo, oauthService: oauthService, openaiOAuthService: openaiOAuthService, + geminiOAuthService: geminiOAuthService, } } @@ -77,6 +81,8 @@ type crsExportResponse struct { ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"` OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"` OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"` + GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"` + GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"` } `json:"data"` } @@ -149,6 +155,37 @@ type crsOpenAIOAuthAccount struct { Extra map[string]any `json:"extra"` } +type crsGeminiOAuthAccount struct { + Kind string `json:"kind"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Platform string `json:"platform"` + AuthType string `json:"authType"` // oauth + IsActive bool `json:"isActive"` + Schedulable bool `json:"schedulable"` + Priority int `json:"priority"` + Status string `json:"status"` + Proxy *crsProxy `json:"proxy"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` +} + +type crsGeminiAPIKeyAccount struct { + Kind string `json:"kind"` + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Platform string `json:"platform"` + IsActive bool `json:"isActive"` + Schedulable bool `json:"schedulable"` + Priority int `json:"priority"` + Status string `json:"status"` + Proxy *crsProxy `json:"proxy"` + Credentials map[string]any `json:"credentials"` + Extra map[string]any `json:"extra"` +} + func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { baseURL, err := normalizeBaseURL(input.BaseURL) if err != nil { @@ -176,7 +213,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput Items: make( []SyncFromCRSItemResult, 0, - len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts), + len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts)+len(exported.Data.GeminiOAuthAccounts)+len(exported.Data.GeminiAPIKeyAccounts), ), } @@ -680,6 +717,225 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput result.Items = append(result.Items, item) } + // Gemini OAuth -> sub2api gemini oauth + for _, src := range exported.Data.GeminiOAuthAccounts { + item := SyncFromCRSItemResult{ + CRSAccountID: src.ID, + Kind: src.Kind, + Name: src.Name, + } + + refreshToken, _ := src.Credentials["refresh_token"].(string) + if strings.TrimSpace(refreshToken) == "" { + item.Action = "failed" + item.Error = "missing refresh_token" + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name)) + if err != nil { + item.Action = "failed" + item.Error = "proxy sync failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + credentials := sanitizeCredentialsMap(src.Credentials) + if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" { + credentials["token_type"] = "Bearer" + } + // Convert expires_at from RFC3339 to Unix seconds string (recommended to keep consistent with GetCredential()) + if expiresAtStr, ok := credentials["expires_at"].(string); ok && strings.TrimSpace(expiresAtStr) != "" { + if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil { + credentials["expires_at"] = strconv.FormatInt(t.Unix(), 10) + } + } + + extra := make(map[string]any) + if src.Extra != nil { + for k, v := range src.Extra { + extra[k] = v + } + } + extra["crs_account_id"] = src.ID + extra["crs_kind"] = src.Kind + extra["crs_synced_at"] = now + + existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID) + if err != nil { + item.Action = "failed" + item.Error = "db lookup failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + if existing == nil { + account := &model.Account{ + Name: defaultName(src.Name, src.ID), + Platform: model.PlatformGemini, + Type: model.AccountTypeOAuth, + Credentials: model.JSONB(credentials), + Extra: model.JSONB(extra), + ProxyID: proxyID, + Concurrency: 3, + Priority: clampPriority(src.Priority), + Status: mapCRSStatus(src.IsActive, src.Status), + Schedulable: src.Schedulable, + } + if err := s.accountRepo.Create(ctx, account); err != nil { + item.Action = "failed" + item.Error = "create failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil { + account.Credentials = refreshedCreds + _ = s.accountRepo.Update(ctx, account) + } + item.Action = "created" + result.Created++ + result.Items = append(result.Items, item) + continue + } + + existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Name = defaultName(src.Name, src.ID) + existing.Platform = model.PlatformGemini + existing.Type = model.AccountTypeOAuth + existing.Credentials = mergeJSONB(existing.Credentials, credentials) + if proxyID != nil { + existing.ProxyID = proxyID + } + existing.Concurrency = 3 + existing.Priority = clampPriority(src.Priority) + existing.Status = mapCRSStatus(src.IsActive, src.Status) + existing.Schedulable = src.Schedulable + + if err := s.accountRepo.Update(ctx, existing); err != nil { + item.Action = "failed" + item.Error = "update failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil { + existing.Credentials = refreshedCreds + _ = s.accountRepo.Update(ctx, existing) + } + + item.Action = "updated" + result.Updated++ + result.Items = append(result.Items, item) + } + + // Gemini API Key -> sub2api gemini apikey + for _, src := range exported.Data.GeminiAPIKeyAccounts { + item := SyncFromCRSItemResult{ + CRSAccountID: src.ID, + Kind: src.Kind, + Name: src.Name, + } + + apiKey, _ := src.Credentials["api_key"].(string) + if strings.TrimSpace(apiKey) == "" { + item.Action = "failed" + item.Error = "missing api_key" + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name)) + if err != nil { + item.Action = "failed" + item.Error = "proxy sync failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + credentials := sanitizeCredentialsMap(src.Credentials) + if baseURL, ok := credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" { + credentials["base_url"] = "https://generativelanguage.googleapis.com" + } + + extra := make(map[string]any) + if src.Extra != nil { + for k, v := range src.Extra { + extra[k] = v + } + } + extra["crs_account_id"] = src.ID + extra["crs_kind"] = src.Kind + extra["crs_synced_at"] = now + + existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID) + if err != nil { + item.Action = "failed" + item.Error = "db lookup failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + if existing == nil { + account := &model.Account{ + Name: defaultName(src.Name, src.ID), + Platform: model.PlatformGemini, + Type: model.AccountTypeApiKey, + Credentials: model.JSONB(credentials), + Extra: model.JSONB(extra), + ProxyID: proxyID, + Concurrency: 3, + Priority: clampPriority(src.Priority), + Status: mapCRSStatus(src.IsActive, src.Status), + Schedulable: src.Schedulable, + } + if err := s.accountRepo.Create(ctx, account); err != nil { + item.Action = "failed" + item.Error = "create failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + item.Action = "created" + result.Created++ + result.Items = append(result.Items, item) + continue + } + + existing.Extra = mergeJSONB(existing.Extra, extra) + existing.Name = defaultName(src.Name, src.ID) + existing.Platform = model.PlatformGemini + existing.Type = model.AccountTypeApiKey + existing.Credentials = mergeJSONB(existing.Credentials, credentials) + if proxyID != nil { + existing.ProxyID = proxyID + } + existing.Concurrency = 3 + existing.Priority = clampPriority(src.Priority) + existing.Status = mapCRSStatus(src.IsActive, src.Status) + existing.Schedulable = src.Schedulable + + if err := s.accountRepo.Update(ctx, existing); err != nil { + item.Action = "failed" + item.Error = "update failed: " + err.Error() + result.Failed++ + result.Items = append(result.Items, item) + continue + } + + item.Action = "updated" + result.Updated++ + result.Items = append(result.Items, item) + } + return result, nil } @@ -947,6 +1203,21 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A } } } + case model.PlatformGemini: + if s.geminiOAuthService == nil { + return nil + } + tokenInfo, refreshErr := s.geminiOAuthService.RefreshAccountToken(ctx, account) + if refreshErr != nil { + err = refreshErr + } else { + newCredentials = s.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + } default: return nil } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 56aab2bc..6d95f12f 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -393,27 +393,32 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing return nil } - // 标准化模型名称 - modelLower := strings.ToLower(modelName) + // 标准化模型名称(同时兼容 "models/xxx"、VertexAI 资源名等前缀) + modelLower := strings.ToLower(strings.TrimSpace(modelName)) + lookupCandidates := s.buildModelLookupCandidates(modelLower) // 1. 精确匹配 - if pricing, ok := s.pricingData[modelLower]; ok { - return pricing - } - if pricing, ok := s.pricingData[modelName]; ok { - return pricing + for _, candidate := range lookupCandidates { + if candidate == "" { + continue + } + if pricing, ok := s.pricingData[candidate]; ok { + return pricing + } } // 2. 处理常见的模型名称变体 // claude-opus-4-5-20251101 -> claude-opus-4.5-20251101 - normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-") - if pricing, ok := s.pricingData[normalized]; ok { - return pricing + for _, candidate := range lookupCandidates { + normalized := strings.ReplaceAll(candidate, "-4-5-", "-4.5-") + if pricing, ok := s.pricingData[normalized]; ok { + return pricing + } } // 3. 尝试模糊匹配(去掉版本号后缀) // claude-opus-4-5-20251101 -> claude-opus-4.5 - baseName := s.extractBaseName(modelLower) + baseName := s.extractBaseName(lookupCandidates[0]) for key, pricing := range s.pricingData { keyBase := s.extractBaseName(strings.ToLower(key)) if keyBase == baseName { @@ -422,18 +427,84 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing } // 4. 基于模型系列匹配(Claude) - if pricing := s.matchByModelFamily(modelLower); pricing != nil { + if pricing := s.matchByModelFamily(lookupCandidates[0]); pricing != nil { return pricing } // 5. OpenAI 模型回退策略 - if strings.HasPrefix(modelLower, "gpt-") { - return s.matchOpenAIModel(modelLower) + if strings.HasPrefix(lookupCandidates[0], "gpt-") { + return s.matchOpenAIModel(lookupCandidates[0]) } return nil } +func (s *PricingService) buildModelLookupCandidates(modelLower string) []string { + // Prefer canonical model name first (this also improves billing compatibility with "models/xxx"). + candidates := []string{ + normalizeModelNameForPricing(modelLower), + modelLower, + } + for _, cand := range []string{ + strings.TrimPrefix(modelLower, "models/"), + lastSegment(modelLower), + lastSegment(strings.TrimPrefix(modelLower, "models/")), + } { + candidates = append(candidates, cand) + } + + seen := make(map[string]struct{}, len(candidates)) + out := make([]string, 0, len(candidates)) + for _, c := range candidates { + c = strings.TrimSpace(c) + if c == "" { + continue + } + if _, ok := seen[c]; ok { + continue + } + seen[c] = struct{}{} + out = append(out, c) + } + if len(out) == 0 { + return []string{modelLower} + } + return out +} + +func normalizeModelNameForPricing(model string) string { + // Common Gemini/VertexAI forms: + // - models/gemini-2.0-flash-exp + // - publishers/google/models/gemini-1.5-pro + // - projects/.../locations/.../publishers/google/models/gemini-1.5-pro + model = strings.TrimSpace(model) + model = strings.TrimLeft(model, "/") + + if strings.HasPrefix(model, "models/") { + model = strings.TrimPrefix(model, "models/") + } + if strings.HasPrefix(model, "publishers/google/models/") { + model = strings.TrimPrefix(model, "publishers/google/models/") + } + + if idx := strings.LastIndex(model, "/publishers/google/models/"); idx != -1 { + model = model[idx+len("/publishers/google/models/"):] + } + if idx := strings.LastIndex(model, "/models/"); idx != -1 { + model = model[idx+len("/models/"):] + } + + model = strings.TrimLeft(model, "/") + return model +} + +func lastSegment(model string) string { + if idx := strings.LastIndex(model, "/"); idx != -1 { + return model[idx+1:] + } + return model +} + // extractBaseName 提取基础模型名称(去掉日期版本号) func (s *PricingService) extractBaseName(model string) string { // 移除日期后缀 (如 -20251101, -20241022) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 24ef7b8e..187a517e 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -27,6 +27,7 @@ func NewTokenRefreshService( accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, + geminiOAuthService *GeminiOAuthService, cfg *config.Config, ) *TokenRefreshService { s := &TokenRefreshService{ @@ -39,6 +40,7 @@ func NewTokenRefreshService( s.refreshers = []TokenRefresher{ NewClaudeTokenRefresher(oauthService), NewOpenAITokenRefresher(openaiOAuthService), + NewGeminiTokenRefresher(geminiOAuthService), } return s