feat(service): 扩展 CRS 同步和定价服务支持 Gemini

- CRS 同步服务新增 Gemini 账号同步逻辑(+273行)
- 定价服务扩展 Gemini 模型定价计算(+99行)
- 更新 Token 刷新服务集成 Gemini
- 更新相关单元测试
This commit is contained in:
ianshaw
2025-12-25 06:44:40 -08:00
parent dc109827b7
commit 55258bf099
4 changed files with 360 additions and 18 deletions

View File

@@ -120,10 +120,9 @@ func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
<-r.Context().Done()
}))
ctx, cancel := context.WithCancel(s.ctx)
@@ -136,7 +135,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)

View File

@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
@@ -20,6 +21,7 @@ type CRSSyncService struct {
proxyRepo ProxyRepository
oauthService *OAuthService
openaiOAuthService *OpenAIOAuthService
geminiOAuthService *GeminiOAuthService
}
func NewCRSSyncService(
@@ -27,12 +29,14 @@ func NewCRSSyncService(
proxyRepo ProxyRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
) *CRSSyncService {
return &CRSSyncService{
accountRepo: accountRepo,
proxyRepo: proxyRepo,
oauthService: oauthService,
openaiOAuthService: openaiOAuthService,
geminiOAuthService: geminiOAuthService,
}
}
@@ -77,6 +81,8 @@ type crsExportResponse struct {
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
} `json:"data"`
}
@@ -149,6 +155,37 @@ type crsOpenAIOAuthAccount struct {
Extra map[string]any `json:"extra"`
}
type crsGeminiOAuthAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
AuthType string `json:"authType"` // oauth
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
type crsGeminiAPIKeyAccount struct {
Kind string `json:"kind"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform"`
IsActive bool `json:"isActive"`
Schedulable bool `json:"schedulable"`
Priority int `json:"priority"`
Status string `json:"status"`
Proxy *crsProxy `json:"proxy"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
}
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
baseURL, err := normalizeBaseURL(input.BaseURL)
if err != nil {
@@ -176,7 +213,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
Items: make(
[]SyncFromCRSItemResult,
0,
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts)+len(exported.Data.GeminiOAuthAccounts)+len(exported.Data.GeminiAPIKeyAccounts),
),
}
@@ -680,6 +717,225 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
result.Items = append(result.Items, item)
}
// Gemini OAuth -> sub2api gemini oauth
for _, src := range exported.Data.GeminiOAuthAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
refreshToken, _ := src.Credentials["refresh_token"].(string)
if strings.TrimSpace(refreshToken) == "" {
item.Action = "failed"
item.Error = "missing refresh_token"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
credentials["token_type"] = "Bearer"
}
// Convert expires_at from RFC3339 to Unix seconds string (recommended to keep consistent with GetCredential())
if expiresAtStr, ok := credentials["expires_at"].(string); ok && strings.TrimSpace(expiresAtStr) != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
credentials["expires_at"] = strconv.FormatInt(t.Unix(), 10)
}
}
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformGemini,
Type: model.AccountTypeOAuth,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: 3,
Priority: clampPriority(src.Priority),
Status: mapCRSStatus(src.IsActive, src.Status),
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
account.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, account)
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformGemini
existing.Type = model.AccountTypeOAuth
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = 3
existing.Priority = clampPriority(src.Priority)
existing.Status = mapCRSStatus(src.IsActive, src.Status)
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
existing.Credentials = refreshedCreds
_ = s.accountRepo.Update(ctx, existing)
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
// Gemini API Key -> sub2api gemini apikey
for _, src := range exported.Data.GeminiAPIKeyAccounts {
item := SyncFromCRSItemResult{
CRSAccountID: src.ID,
Kind: src.Kind,
Name: src.Name,
}
apiKey, _ := src.Credentials["api_key"].(string)
if strings.TrimSpace(apiKey) == "" {
item.Action = "failed"
item.Error = "missing api_key"
result.Failed++
result.Items = append(result.Items, item)
continue
}
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
if err != nil {
item.Action = "failed"
item.Error = "proxy sync failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
credentials := sanitizeCredentialsMap(src.Credentials)
if baseURL, ok := credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
credentials["base_url"] = "https://generativelanguage.googleapis.com"
}
extra := make(map[string]any)
if src.Extra != nil {
for k, v := range src.Extra {
extra[k] = v
}
}
extra["crs_account_id"] = src.ID
extra["crs_kind"] = src.Kind
extra["crs_synced_at"] = now
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
if err != nil {
item.Action = "failed"
item.Error = "db lookup failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
if existing == nil {
account := &model.Account{
Name: defaultName(src.Name, src.ID),
Platform: model.PlatformGemini,
Type: model.AccountTypeApiKey,
Credentials: model.JSONB(credentials),
Extra: model.JSONB(extra),
ProxyID: proxyID,
Concurrency: 3,
Priority: clampPriority(src.Priority),
Status: mapCRSStatus(src.IsActive, src.Status),
Schedulable: src.Schedulable,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
item.Action = "failed"
item.Error = "create failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "created"
result.Created++
result.Items = append(result.Items, item)
continue
}
existing.Extra = mergeJSONB(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = model.PlatformGemini
existing.Type = model.AccountTypeApiKey
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
}
existing.Concurrency = 3
existing.Priority = clampPriority(src.Priority)
existing.Status = mapCRSStatus(src.IsActive, src.Status)
existing.Schedulable = src.Schedulable
if err := s.accountRepo.Update(ctx, existing); err != nil {
item.Action = "failed"
item.Error = "update failed: " + err.Error()
result.Failed++
result.Items = append(result.Items, item)
continue
}
item.Action = "updated"
result.Updated++
result.Items = append(result.Items, item)
}
return result, nil
}
@@ -947,6 +1203,21 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
}
}
}
case model.PlatformGemini:
if s.geminiOAuthService == nil {
return nil
}
tokenInfo, refreshErr := s.geminiOAuthService.RefreshAccountToken(ctx, account)
if refreshErr != nil {
err = refreshErr
} else {
newCredentials = s.geminiOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
newCredentials[k] = v
}
}
}
default:
return nil
}

View File

@@ -393,27 +393,32 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing
return nil
}
// 标准化模型名称
modelLower := strings.ToLower(modelName)
// 标准化模型名称(同时兼容 "models/xxx"、VertexAI 资源名等前缀)
modelLower := strings.ToLower(strings.TrimSpace(modelName))
lookupCandidates := s.buildModelLookupCandidates(modelLower)
// 1. 精确匹配
if pricing, ok := s.pricingData[modelLower]; ok {
return pricing
}
if pricing, ok := s.pricingData[modelName]; ok {
return pricing
for _, candidate := range lookupCandidates {
if candidate == "" {
continue
}
if pricing, ok := s.pricingData[candidate]; ok {
return pricing
}
}
// 2. 处理常见的模型名称变体
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-")
if pricing, ok := s.pricingData[normalized]; ok {
return pricing
for _, candidate := range lookupCandidates {
normalized := strings.ReplaceAll(candidate, "-4-5-", "-4.5-")
if pricing, ok := s.pricingData[normalized]; ok {
return pricing
}
}
// 3. 尝试模糊匹配(去掉版本号后缀)
// claude-opus-4-5-20251101 -> claude-opus-4.5
baseName := s.extractBaseName(modelLower)
baseName := s.extractBaseName(lookupCandidates[0])
for key, pricing := range s.pricingData {
keyBase := s.extractBaseName(strings.ToLower(key))
if keyBase == baseName {
@@ -422,18 +427,84 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing
}
// 4. 基于模型系列匹配Claude
if pricing := s.matchByModelFamily(modelLower); pricing != nil {
if pricing := s.matchByModelFamily(lookupCandidates[0]); pricing != nil {
return pricing
}
// 5. OpenAI 模型回退策略
if strings.HasPrefix(modelLower, "gpt-") {
return s.matchOpenAIModel(modelLower)
if strings.HasPrefix(lookupCandidates[0], "gpt-") {
return s.matchOpenAIModel(lookupCandidates[0])
}
return nil
}
func (s *PricingService) buildModelLookupCandidates(modelLower string) []string {
// Prefer canonical model name first (this also improves billing compatibility with "models/xxx").
candidates := []string{
normalizeModelNameForPricing(modelLower),
modelLower,
}
for _, cand := range []string{
strings.TrimPrefix(modelLower, "models/"),
lastSegment(modelLower),
lastSegment(strings.TrimPrefix(modelLower, "models/")),
} {
candidates = append(candidates, cand)
}
seen := make(map[string]struct{}, len(candidates))
out := make([]string, 0, len(candidates))
for _, c := range candidates {
c = strings.TrimSpace(c)
if c == "" {
continue
}
if _, ok := seen[c]; ok {
continue
}
seen[c] = struct{}{}
out = append(out, c)
}
if len(out) == 0 {
return []string{modelLower}
}
return out
}
func normalizeModelNameForPricing(model string) string {
// Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-1.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
model = strings.TrimSpace(model)
model = strings.TrimLeft(model, "/")
if strings.HasPrefix(model, "models/") {
model = strings.TrimPrefix(model, "models/")
}
if strings.HasPrefix(model, "publishers/google/models/") {
model = strings.TrimPrefix(model, "publishers/google/models/")
}
if idx := strings.LastIndex(model, "/publishers/google/models/"); idx != -1 {
model = model[idx+len("/publishers/google/models/"):]
}
if idx := strings.LastIndex(model, "/models/"); idx != -1 {
model = model[idx+len("/models/"):]
}
model = strings.TrimLeft(model, "/")
return model
}
func lastSegment(model string) string {
if idx := strings.LastIndex(model, "/"); idx != -1 {
return model[idx+1:]
}
return model
}
// extractBaseName 提取基础模型名称(去掉日期版本号)
func (s *PricingService) extractBaseName(model string) string {
// 移除日期后缀 (如 -20251101, -20241022)

View File

@@ -27,6 +27,7 @@ func NewTokenRefreshService(
accountRepo AccountRepository,
oauthService *OAuthService,
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
@@ -39,6 +40,7 @@ func NewTokenRefreshService(
s.refreshers = []TokenRefresher{
NewClaudeTokenRefresher(oauthService),
NewOpenAITokenRefresher(openaiOAuthService),
NewGeminiTokenRefresher(geminiOAuthService),
}
return s