feat(service): 扩展 CRS 同步和定价服务支持 Gemini
- CRS 同步服务新增 Gemini 账号同步逻辑(+273行) - 定价服务扩展 Gemini 模型定价计算(+99行) - 更新 Token 刷新服务集成 Gemini - 更新相关单元测试
This commit is contained in:
@@ -120,10 +120,9 @@ func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
@@ -136,7 +135,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +21,7 @@ type CRSSyncService struct {
|
||||
proxyRepo ProxyRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
}
|
||||
|
||||
func NewCRSSyncService(
|
||||
@@ -27,12 +29,14 @@ func NewCRSSyncService(
|
||||
proxyRepo ProxyRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
) *CRSSyncService {
|
||||
return &CRSSyncService{
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +81,8 @@ type crsExportResponse struct {
|
||||
ClaudeConsoleAccounts []crsConsoleAccount `json:"claudeConsoleAccounts"`
|
||||
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
|
||||
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
|
||||
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
|
||||
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -149,6 +155,37 @@ type crsOpenAIOAuthAccount struct {
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
type crsGeminiOAuthAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
AuthType string `json:"authType"` // oauth
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
type crsGeminiAPIKeyAccount struct {
|
||||
Kind string `json:"kind"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Platform string `json:"platform"`
|
||||
IsActive bool `json:"isActive"`
|
||||
Schedulable bool `json:"schedulable"`
|
||||
Priority int `json:"priority"`
|
||||
Status string `json:"status"`
|
||||
Proxy *crsProxy `json:"proxy"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) {
|
||||
baseURL, err := normalizeBaseURL(input.BaseURL)
|
||||
if err != nil {
|
||||
@@ -176,7 +213,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
Items: make(
|
||||
[]SyncFromCRSItemResult,
|
||||
0,
|
||||
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts),
|
||||
len(exported.Data.ClaudeAccounts)+len(exported.Data.ClaudeConsoleAccounts)+len(exported.Data.OpenAIOAuthAccounts)+len(exported.Data.OpenAIResponsesAccounts)+len(exported.Data.GeminiOAuthAccounts)+len(exported.Data.GeminiAPIKeyAccounts),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -680,6 +717,225 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// Gemini OAuth -> sub2api gemini oauth
|
||||
for _, src := range exported.Data.GeminiOAuthAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
refreshToken, _ := src.Credentials["refresh_token"].(string)
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing refresh_token"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
if v, ok := credentials["token_type"].(string); !ok || strings.TrimSpace(v) == "" {
|
||||
credentials["token_type"] = "Bearer"
|
||||
}
|
||||
// Convert expires_at from RFC3339 to Unix seconds string (recommended to keep consistent with GetCredential())
|
||||
if expiresAtStr, ok := credentials["expires_at"].(string); ok && strings.TrimSpace(expiresAtStr) != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
credentials["expires_at"] = strconv.FormatInt(t.Unix(), 10)
|
||||
}
|
||||
}
|
||||
|
||||
extra := make(map[string]any)
|
||||
if src.Extra != nil {
|
||||
for k, v := range src.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
extra["crs_account_id"] = src.ID
|
||||
extra["crs_kind"] = src.Kind
|
||||
extra["crs_synced_at"] = now
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformGemini,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: 3,
|
||||
Priority: clampPriority(src.Priority),
|
||||
Status: mapCRSStatus(src.IsActive, src.Status),
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformGemini
|
||||
existing.Type = model.AccountTypeOAuth
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = 3
|
||||
existing.Priority = clampPriority(src.Priority)
|
||||
existing.Status = mapCRSStatus(src.IsActive, src.Status)
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
// Gemini API Key -> sub2api gemini apikey
|
||||
for _, src := range exported.Data.GeminiAPIKeyAccounts {
|
||||
item := SyncFromCRSItemResult{
|
||||
CRSAccountID: src.ID,
|
||||
Kind: src.Kind,
|
||||
Name: src.Name,
|
||||
}
|
||||
|
||||
apiKey, _ := src.Credentials["api_key"].(string)
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
item.Action = "failed"
|
||||
item.Error = "missing api_key"
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
proxyID, err := s.mapOrCreateProxy(ctx, input.SyncProxies, &proxies, src.Proxy, fmt.Sprintf("crs-%s", src.Name))
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "proxy sync failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
credentials := sanitizeCredentialsMap(src.Credentials)
|
||||
if baseURL, ok := credentials["base_url"].(string); !ok || strings.TrimSpace(baseURL) == "" {
|
||||
credentials["base_url"] = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
|
||||
extra := make(map[string]any)
|
||||
if src.Extra != nil {
|
||||
for k, v := range src.Extra {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
extra["crs_account_id"] = src.ID
|
||||
extra["crs_kind"] = src.Kind
|
||||
extra["crs_synced_at"] = now
|
||||
|
||||
existing, err := s.accountRepo.GetByCRSAccountID(ctx, src.ID)
|
||||
if err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "db lookup failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformGemini,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
ProxyID: proxyID,
|
||||
Concurrency: 3,
|
||||
Priority: clampPriority(src.Priority),
|
||||
Status: mapCRSStatus(src.IsActive, src.Status),
|
||||
Schedulable: src.Schedulable,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "create failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
item.Action = "created"
|
||||
result.Created++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformGemini
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
existing.Concurrency = 3
|
||||
existing.Priority = clampPriority(src.Priority)
|
||||
existing.Status = mapCRSStatus(src.IsActive, src.Status)
|
||||
existing.Schedulable = src.Schedulable
|
||||
|
||||
if err := s.accountRepo.Update(ctx, existing); err != nil {
|
||||
item.Action = "failed"
|
||||
item.Error = "update failed: " + err.Error()
|
||||
result.Failed++
|
||||
result.Items = append(result.Items, item)
|
||||
continue
|
||||
}
|
||||
|
||||
item.Action = "updated"
|
||||
result.Updated++
|
||||
result.Items = append(result.Items, item)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -947,6 +1203,21 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
|
||||
}
|
||||
}
|
||||
}
|
||||
case model.PlatformGemini:
|
||||
if s.geminiOAuthService == nil {
|
||||
return nil
|
||||
}
|
||||
tokenInfo, refreshErr := s.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if refreshErr != nil {
|
||||
err = refreshErr
|
||||
} else {
|
||||
newCredentials = s.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -393,27 +393,32 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标准化模型名称
|
||||
modelLower := strings.ToLower(modelName)
|
||||
// 标准化模型名称(同时兼容 "models/xxx"、VertexAI 资源名等前缀)
|
||||
modelLower := strings.ToLower(strings.TrimSpace(modelName))
|
||||
lookupCandidates := s.buildModelLookupCandidates(modelLower)
|
||||
|
||||
// 1. 精确匹配
|
||||
if pricing, ok := s.pricingData[modelLower]; ok {
|
||||
return pricing
|
||||
}
|
||||
if pricing, ok := s.pricingData[modelName]; ok {
|
||||
return pricing
|
||||
for _, candidate := range lookupCandidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if pricing, ok := s.pricingData[candidate]; ok {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 处理常见的模型名称变体
|
||||
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
|
||||
normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-")
|
||||
if pricing, ok := s.pricingData[normalized]; ok {
|
||||
return pricing
|
||||
for _, candidate := range lookupCandidates {
|
||||
normalized := strings.ReplaceAll(candidate, "-4-5-", "-4.5-")
|
||||
if pricing, ok := s.pricingData[normalized]; ok {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 尝试模糊匹配(去掉版本号后缀)
|
||||
// claude-opus-4-5-20251101 -> claude-opus-4.5
|
||||
baseName := s.extractBaseName(modelLower)
|
||||
baseName := s.extractBaseName(lookupCandidates[0])
|
||||
for key, pricing := range s.pricingData {
|
||||
keyBase := s.extractBaseName(strings.ToLower(key))
|
||||
if keyBase == baseName {
|
||||
@@ -422,18 +427,84 @@ func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing
|
||||
}
|
||||
|
||||
// 4. 基于模型系列匹配(Claude)
|
||||
if pricing := s.matchByModelFamily(modelLower); pricing != nil {
|
||||
if pricing := s.matchByModelFamily(lookupCandidates[0]); pricing != nil {
|
||||
return pricing
|
||||
}
|
||||
|
||||
// 5. OpenAI 模型回退策略
|
||||
if strings.HasPrefix(modelLower, "gpt-") {
|
||||
return s.matchOpenAIModel(modelLower)
|
||||
if strings.HasPrefix(lookupCandidates[0], "gpt-") {
|
||||
return s.matchOpenAIModel(lookupCandidates[0])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PricingService) buildModelLookupCandidates(modelLower string) []string {
|
||||
// Prefer canonical model name first (this also improves billing compatibility with "models/xxx").
|
||||
candidates := []string{
|
||||
normalizeModelNameForPricing(modelLower),
|
||||
modelLower,
|
||||
}
|
||||
for _, cand := range []string{
|
||||
strings.TrimPrefix(modelLower, "models/"),
|
||||
lastSegment(modelLower),
|
||||
lastSegment(strings.TrimPrefix(modelLower, "models/")),
|
||||
} {
|
||||
candidates = append(candidates, cand)
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(candidates))
|
||||
out := make([]string, 0, len(candidates))
|
||||
for _, c := range candidates {
|
||||
c = strings.TrimSpace(c)
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
out = append(out, c)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return []string{modelLower}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeModelNameForPricing(model string) string {
|
||||
// Common Gemini/VertexAI forms:
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-1.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
|
||||
if strings.HasPrefix(model, "models/") {
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
}
|
||||
if strings.HasPrefix(model, "publishers/google/models/") {
|
||||
model = strings.TrimPrefix(model, "publishers/google/models/")
|
||||
}
|
||||
|
||||
if idx := strings.LastIndex(model, "/publishers/google/models/"); idx != -1 {
|
||||
model = model[idx+len("/publishers/google/models/"):]
|
||||
}
|
||||
if idx := strings.LastIndex(model, "/models/"); idx != -1 {
|
||||
model = model[idx+len("/models/"):]
|
||||
}
|
||||
|
||||
model = strings.TrimLeft(model, "/")
|
||||
return model
|
||||
}
|
||||
|
||||
func lastSegment(model string) string {
|
||||
if idx := strings.LastIndex(model, "/"); idx != -1 {
|
||||
return model[idx+1:]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// extractBaseName 提取基础模型名称(去掉日期版本号)
|
||||
func (s *PricingService) extractBaseName(model string) string {
|
||||
// 移除日期后缀 (如 -20251101, -20241022)
|
||||
|
||||
@@ -27,6 +27,7 @@ func NewTokenRefreshService(
|
||||
accountRepo AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
@@ -39,6 +40,7 @@ func NewTokenRefreshService(
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
NewOpenAITokenRefresher(openaiOAuthService),
|
||||
NewGeminiTokenRefresher(geminiOAuthService),
|
||||
}
|
||||
|
||||
return s
|
||||
|
||||
Reference in New Issue
Block a user