diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go new file mode 100644 index 00000000..49fe7135 --- /dev/null +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -0,0 +1,1298 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + mathrand "math/rand" + "net/http" + "regexp" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/service/ports" + + "github.com/gin-gonic/gin" +) + +const geminiStickySessionTTL = time.Hour + +const ( + geminiMaxRetries = 5 + geminiRetryBaseDelay = 1 * time.Second + geminiRetryMaxDelay = 16 * time.Second +) + +type GeminiMessagesCompatService struct { + accountRepo ports.AccountRepository + cache ports.GatewayCache + tokenProvider *GeminiTokenProvider + rateLimitService *RateLimitService + httpUpstream ports.HTTPUpstream +} + +func NewGeminiMessagesCompatService( + accountRepo ports.AccountRepository, + cache ports.GatewayCache, + tokenProvider *GeminiTokenProvider, + rateLimitService *RateLimitService, + httpUpstream ports.HTTPUpstream, +) *GeminiMessagesCompatService { + return &GeminiMessagesCompatService{ + accountRepo: accountRepo, + cache: cache, + tokenProvider: tokenProvider, + rateLimitService: rateLimitService, + httpUpstream: httpUpstream, + } +} + +func (s *GeminiMessagesCompatService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) { + cacheKey := "gemini:" + sessionHash + if sessionHash != "" { + accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) + if err == nil && accountID > 0 { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err == nil && account.IsSchedulable() && account.Platform == model.PlatformGemini && (requestedModel == "" || account.IsModelSupported(requestedModel)) { + _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + return account, nil + } + } + } + + var accounts []model.Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformGemini) + } else { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformGemini) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + + var selected *model.Account + for i := range accounts { + acc := &accounts[i] + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { + continue + } + if selected == nil { + selected = acc + continue + } + if acc.Priority < selected.Priority { + selected = acc + } else if acc.Priority == selected.Priority { + if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) { + selected = acc + } + } + } + + if selected == nil { + if requestedModel != "" { + return nil, fmt.Errorf("no available Gemini accounts supporting model: %s", requestedModel) + } + return nil, errors.New("no available Gemini accounts") + } + + if sessionHash != "" { + _ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL) + } + + return selected, nil +} + +func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + + var req struct { + Model string `json:"model"` + Stream bool `json:"stream"` + } + if err := json.Unmarshal(body, &req); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + if strings.TrimSpace(req.Model) == "" { + return nil, fmt.Errorf("missing model") + } + + originalModel := req.Model + mappedModel := req.Model + if account.Type == model.AccountTypeApiKey { + mappedModel = account.GetMappedModel(req.Model) + } + + geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + var requestIDHeader string + var buildReq func(ctx context.Context) (*http.Request, string, error) + + switch account.Type { + case model.AccountTypeApiKey: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("Gemini api_key not configured") + } + + baseURL := strings.TrimRight(account.GetCredential("base_url"), "/") + if baseURL == "" { + baseURL = geminicli.AIStudioBaseURL + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(baseURL, "/"), mappedModel, action) + if req.Stream { + fullURL += "?alt=sse" + } + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + case model.AccountTypeOAuth: + buildReq = func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("Gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID == "" { + return nil, "", errors.New("missing project_id in account credentials") + } + + action := "generateContent" + if req.Stream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1internal:%s", geminicli.GeminiCliBaseURL, action) + if req.Stream { + fullURL += "?alt=sse" + } + + wrapped := map[string]any{ + "model": mappedModel, + "project": projectID, + } + var inner any + if err := json.Unmarshal(geminiReq, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrapped["request"] = inner + wrappedBytes, _ := json.Marshal(wrapped) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } + requestIDHeader = "x-request-id" + + default: + return nil, fmt.Errorf("unsupported account type: %s", account.Type) + } + + if buildReq == nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Gemini upstream not configured") + } + + var resp *http.Response + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + // Local build error: don't retry. + if strings.Contains(err.Error(), "missing project_id") { + return nil, s.writeClaudeError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", err.Error()) + } + requestIDHeader = idHeader + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL) + if err != nil { + if attempt < geminiMaxRetries { + log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if resp.StatusCode == 429 { + // Mark as rate-limited early so concurrent requests avoid this account. + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") + } + + break + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + return nil, s.writeGeminiMappedError(c, resp.StatusCode, respBody) + } + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if req.Stream { + streamRes, err := s.handleStreamingResponse(c, resp, startTime, originalModel) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponse(c, resp, originalModel) + if err != nil { + return nil, err + } + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func (s *GeminiMessagesCompatService) shouldRetryGeminiUpstreamError(account *model.Account, statusCode int) bool { + switch statusCode { + case 429, 500, 502, 503, 504, 529: + return true + case 403: + // GeminiCli OAuth occasionally returns 403 transiently (activation/quota propagation); allow retry. + return account != nil && account.Type == model.AccountTypeOAuth + default: + return false + } +} + +func sleepGeminiBackoff(attempt int) { + delay := geminiRetryBaseDelay * time.Duration(1< geminiRetryMaxDelay { + delay = geminiRetryMaxDelay + } + + // +/- 20% jitter + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1)) + sleepFor := delay + jitter + if sleepFor < 0 { + sleepFor = 0 + } + time.Sleep(sleepFor) +} + +func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, upstreamStatus int, body []byte) error { + var statusCode int + var errType, errMsg string + + if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil { + errType = mapped.Type + if mapped.Message != "" { + errMsg = mapped.Message + } + if mapped.StatusCode > 0 { + statusCode = mapped.StatusCode + } + } + + switch upstreamStatus { + case 400: + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + if errType == "" { + errType = "invalid_request_error" + } + if errMsg == "" { + errMsg = "Invalid request" + } + case 401: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "authentication_error" + } + if errMsg == "" { + errMsg = "Upstream authentication failed, please contact administrator" + } + case 403: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "permission_error" + } + if errMsg == "" { + errMsg = "Upstream access forbidden, please contact administrator" + } + case 404: + if statusCode == 0 { + statusCode = http.StatusNotFound + } + if errType == "" { + errType = "not_found_error" + } + if errMsg == "" { + errMsg = "Resource not found" + } + case 429: + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + if errType == "" { + errType = "rate_limit_error" + } + if errMsg == "" { + errMsg = "Upstream rate limit exceeded, please retry later" + } + case 529: + if statusCode == 0 { + statusCode = http.StatusServiceUnavailable + } + if errType == "" { + errType = "overloaded_error" + } + if errMsg == "" { + errMsg = "Upstream service overloaded, please retry later" + } + case 500, 502, 503, 504: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + switch upstreamStatus { + case 504: + errType = "timeout_error" + case 503: + errType = "overloaded_error" + default: + errType = "api_error" + } + } + if errMsg == "" { + errMsg = "Upstream service temporarily unavailable" + } + default: + if statusCode == 0 { + statusCode = http.StatusBadGateway + } + if errType == "" { + errType = "upstream_error" + } + if errMsg == "" { + errMsg = "Upstream request failed" + } + } + + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + return fmt.Errorf("upstream error: %d", upstreamStatus) +} + +type claudeErrorMapping struct { + Type string + Message string + StatusCode int +} + +func mapGeminiErrorBodyToClaudeError(body []byte) *claudeErrorMapping { + if len(body) == 0 { + return nil + } + + var parsed struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + if strings.TrimSpace(parsed.Error.Status) == "" && parsed.Error.Code == 0 && strings.TrimSpace(parsed.Error.Message) == "" { + return nil + } + + mapped := &claudeErrorMapping{ + Type: mapGeminiStatusToClaudeErrorType(parsed.Error.Status), + Message: "", + } + if mapped.Type == "" { + mapped.Type = "upstream_error" + } + + switch strings.ToUpper(strings.TrimSpace(parsed.Error.Status)) { + case "INVALID_ARGUMENT": + mapped.StatusCode = http.StatusBadRequest + case "NOT_FOUND": + mapped.StatusCode = http.StatusNotFound + case "RESOURCE_EXHAUSTED": + mapped.StatusCode = http.StatusTooManyRequests + default: + // Keep StatusCode unset and let HTTP status mapping decide. + } + + // Keep messages generic by default; upstream error message can be long or include sensitive fragments. + return mapped +} + +func mapGeminiStatusToClaudeErrorType(status string) string { + switch strings.ToUpper(strings.TrimSpace(status)) { + case "INVALID_ARGUMENT": + return "invalid_request_error" + case "PERMISSION_DENIED": + return "permission_error" + case "NOT_FOUND": + return "not_found_error" + case "RESOURCE_EXHAUSTED": + return "rate_limit_error" + case "UNAUTHENTICATED": + return "authentication_error" + case "UNAVAILABLE": + return "overloaded_error" + case "INTERNAL": + return "api_error" + case "DEADLINE_EXCEEDED": + return "timeout_error" + default: + return "" + } +} + +type geminiStreamResult struct { + usage *ClaudeUsage + firstTokenMs *int +} + +func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, resp *http.Response, originalModel string) (*ClaudeUsage, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") + } + + geminiResp, err := unwrapGeminiResponse(body) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel) + c.JSON(http.StatusOK, claudeResp) + + return usage, nil +} + +func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, originalModel string) (*geminiStreamResult, error) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + messageID := "msg_" + randomHex(12) + messageStart := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": messageID, + "type": "message", + "role": "assistant", + "model": originalModel, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": 0, + "output_tokens": 0, + }, + }, + } + writeSSE(c.Writer, "message_start", messageStart) + flusher.Flush() + + var firstTokenMs *int + var usage ClaudeUsage + finishReason := "" + sawToolUse := false + + nextBlockIndex := 0 + openBlockIndex := -1 + openBlockType := "" + seenText := "" + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return nil, fmt.Errorf("stream read error: %w", err) + } + + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" || payload == "[DONE]" { + continue + } + + geminiResp, err := unwrapGeminiResponse([]byte(payload)) + if err != nil { + continue + } + + if fr := extractGeminiFinishReason(geminiResp); fr != "" { + finishReason = fr + } + + parts := extractGeminiParts(geminiResp) + for _, part := range parts { + if text, ok := part["text"].(string); ok && text != "" { + delta, newSeen := computeGeminiTextDelta(seenText, text) + seenText = newSeen + if delta == "" { + continue + } + + if openBlockType != "text" { + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + } + openBlockType = "text" + openBlockIndex = nextBlockIndex + nextBlockIndex++ + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": openBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + } + + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": openBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": delta, + }, + }) + flusher.Flush() + continue + } + + if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil { + name, _ := fc["name"].(string) + args := fc["args"] + if strings.TrimSpace(name) == "" { + name = "tool" + } + + // Close any open block before tool_use. + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + openBlockIndex = -1 + openBlockType = "" + } + + toolID := "toolu_" + randomHex(8) + toolIndex := nextBlockIndex + nextBlockIndex++ + sawToolUse = true + + writeSSE(c.Writer, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": toolIndex, + "content_block": map[string]any{ + "type": "tool_use", + "id": toolID, + "name": name, + "input": map[string]any{}, + }, + }) + + argsJSON := "{}" + if args != nil { + if b, err := json.Marshal(args); err == nil { + argsJSON = string(b) + } + } + writeSSE(c.Writer, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": toolIndex, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": argsJSON, + }, + }) + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": toolIndex, + }) + flusher.Flush() + } + } + + if u := extractGeminiUsage(geminiResp); u != nil { + usage = *u + } + } + + if openBlockIndex >= 0 { + writeSSE(c.Writer, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": openBlockIndex, + }) + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason) + if sawToolUse { + stopReason = "tool_use" + } + + usageObj := map[string]any{ + "output_tokens": usage.OutputTokens, + } + if usage.InputTokens > 0 { + usageObj["input_tokens"] = usage.InputTokens + } + writeSSE(c.Writer, "message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": usageObj, + }) + writeSSE(c.Writer, "message_stop", map[string]any{ + "type": "message_stop", + }) + flusher.Flush() + + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil +} + +func writeSSE(w io.Writer, event string, data any) { + if event != "" { + _, _ = fmt.Fprintf(w, "event: %s\n", event) + } + b, _ := json.Marshal(data) + _, _ = fmt.Fprintf(w, "data: %s\n\n", string(b)) +} + +func randomHex(nBytes int) string { + b := make([]byte, nBytes) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + +func (s *GeminiMessagesCompatService) writeClaudeError(c *gin.Context, status int, errType, message string) error { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": message}, + }) + return fmt.Errorf("%s", message) +} + +func unwrapGeminiResponse(raw []byte) (map[string]any, error) { + var outer map[string]any + if err := json.Unmarshal(raw, &outer); err != nil { + return nil, err + } + if resp, ok := outer["response"].(map[string]any); ok && resp != nil { + return resp, nil + } + return outer, nil +} + +func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) { + usage := extractGeminiUsage(geminiResp) + if usage == nil { + usage = &ClaudeUsage{} + } + + contentBlocks := make([]any, 0) + sawToolUse := false + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if content, ok := cand["content"].(map[string]any); ok { + if parts, ok := content["parts"].([]any); ok { + for _, part := range parts { + pm, ok := part.(map[string]any) + if !ok { + continue + } + if text, ok := pm["text"].(string); ok && text != "" { + contentBlocks = append(contentBlocks, map[string]any{ + "type": "text", + "text": text, + }) + } + if fc, ok := pm["functionCall"].(map[string]any); ok { + name, _ := fc["name"].(string) + if strings.TrimSpace(name) == "" { + name = "tool" + } + args := fc["args"] + sawToolUse = true + contentBlocks = append(contentBlocks, map[string]any{ + "type": "tool_use", + "id": "toolu_" + randomHex(8), + "name": name, + "input": args, + }) + } + } + } + } + } + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(extractGeminiFinishReason(geminiResp)) + if sawToolUse { + stopReason = "tool_use" + } + + resp := map[string]any{ + "id": "msg_" + randomHex(12), + "type": "message", + "role": "assistant", + "model": originalModel, + "content": contentBlocks, + "stop_reason": stopReason, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.InputTokens, + "output_tokens": usage.OutputTokens, + }, + } + + return resp, usage +} + +func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { + usageMeta, ok := geminiResp["usageMetadata"].(map[string]any) + if !ok || usageMeta == nil { + return nil + } + prompt, _ := asInt(usageMeta["promptTokenCount"]) + cand, _ := asInt(usageMeta["candidatesTokenCount"]) + return &ClaudeUsage{ + InputTokens: prompt, + OutputTokens: cand, + } +} + +func asInt(v any) (int, bool) { + switch t := v.(type) { + case float64: + return int(t), true + case int: + return t, true + case int64: + return int(t), true + case json.Number: + i, err := t.Int64() + if err != nil { + return 0, false + } + return int(i), true + default: + return 0, false + } +} + +func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, body []byte) { + if s.rateLimitService != nil && (statusCode == 401 || statusCode == 403 || statusCode == 529) { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) + return + } + if statusCode != 429 { + return + } + resetAt := parseGeminiRateLimitResetTime(body) + if resetAt == nil { + ra := time.Now().Add(5 * time.Minute) + _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) + return + } + _ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0)) +} + +func parseGeminiRateLimitResetTime(body []byte) *int64 { + // Try to parse metadata.quotaResetDelay like "12.345s" + var parsed map[string]any + if err := json.Unmarshal(body, &parsed); err == nil { + if errObj, ok := parsed["error"].(map[string]any); ok { + if msg, ok := errObj["message"].(string); ok { + if looksLikeGeminiDailyQuota(msg) { + if ts := nextGeminiDailyResetUnix(); ts != nil { + return ts + } + } + } + if details, ok := errObj["details"].([]any); ok { + for _, d := range details { + dm, ok := d.(map[string]any) + if !ok { + continue + } + if meta, ok := dm["metadata"].(map[string]any); ok { + if v, ok := meta["quotaResetDelay"].(string); ok { + if dur, err := time.ParseDuration(v); err == nil { + ts := time.Now().Unix() + int64(dur.Seconds()) + return &ts + } + } + } + } + } + } + } + + // Match "Please retry in Xs" + retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`) + matches := retryInRegex.FindStringSubmatch(string(body)) + if len(matches) == 2 { + if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { + ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + return &ts + } + } + + return nil +} + +func looksLikeGeminiDailyQuota(message string) bool { + m := strings.ToLower(message) + if strings.Contains(m, "per day") || strings.Contains(m, "requests per day") || strings.Contains(m, "quota") && strings.Contains(m, "per day") { + return true + } + return false +} + +func nextGeminiDailyResetUnix() *int64 { + loc, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + // Fallback: PST without DST. + loc = time.FixedZone("PST", -8*3600) + } + now := time.Now().In(loc) + reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc) + if !reset.After(now) { + reset = reset.Add(24 * time.Hour) + } + ts := reset.Unix() + return &ts +} + +func extractGeminiFinishReason(geminiResp map[string]any) string { + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if fr, ok := cand["finishReason"].(string); ok { + return fr + } + } + } + return "" +} + +func extractGeminiParts(geminiResp map[string]any) []map[string]any { + if candidates, ok := geminiResp["candidates"].([]any); ok && len(candidates) > 0 { + if cand, ok := candidates[0].(map[string]any); ok { + if content, ok := cand["content"].(map[string]any); ok { + if partsAny, ok := content["parts"].([]any); ok && len(partsAny) > 0 { + out := make([]map[string]any, 0, len(partsAny)) + for _, p := range partsAny { + pm, ok := p.(map[string]any) + if !ok { + continue + } + out = append(out, pm) + } + return out + } + } + } + } + return nil +} + +func computeGeminiTextDelta(seen, incoming string) (delta, newSeen string) { + incoming = strings.TrimSuffix(incoming, "\u0000") + if incoming == "" { + return "", seen + } + + // Cumulative mode: incoming contains full text so far. + if strings.HasPrefix(incoming, seen) { + return strings.TrimPrefix(incoming, seen), incoming + } + // Duplicate/rewind: ignore. + if strings.HasPrefix(seen, incoming) { + return "", seen + } + // Delta mode: treat incoming as incremental chunk. + return incoming, seen + incoming +} + +func mapGeminiFinishReasonToClaudeStopReason(finishReason string) string { + switch strings.ToUpper(strings.TrimSpace(finishReason)) { + case "MAX_TOKENS": + return "max_tokens" + case "STOP": + return "end_turn" + default: + return "end_turn" + } +} + +func convertClaudeMessagesToGeminiGenerateContent(body []byte) ([]byte, error) { + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + toolUseIDToName := make(map[string]string) + + systemText := extractClaudeSystemText(req["system"]) + contents, err := convertClaudeMessagesToGeminiContents(req["messages"], toolUseIDToName) + if err != nil { + return nil, err + } + + out := make(map[string]any) + if systemText != "" { + out["systemInstruction"] = map[string]any{ + "parts": []any{map[string]any{"text": systemText}}, + } + } + out["contents"] = contents + + if tools := convertClaudeToolsToGeminiTools(req["tools"]); tools != nil { + out["tools"] = tools + } + + generationConfig := convertClaudeGenerationConfig(req) + if generationConfig != nil { + out["generationConfig"] = generationConfig + } + + stripGeminiFunctionIDs(out) + return json.Marshal(out) +} + +func stripGeminiFunctionIDs(req map[string]any) { + // Defensive cleanup: some upstreams reject unexpected `id` fields in functionCall/functionResponse. + contents, ok := req["contents"].([]any) + if !ok { + return + } + for _, c := range contents { + cm, ok := c.(map[string]any) + if !ok { + continue + } + contentParts, ok := cm["parts"].([]any) + if !ok { + continue + } + for _, p := range contentParts { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if fc, ok := pm["functionCall"].(map[string]any); ok && fc != nil { + delete(fc, "id") + } + if fr, ok := pm["functionResponse"].(map[string]any); ok && fr != nil { + delete(fr, "id") + } + } + } +} + +func extractClaudeSystemText(system any) string { + switch v := system.(type) { + case string: + return strings.TrimSpace(v) + case []any: + var parts []string + for _, p := range v { + pm, ok := p.(map[string]any) + if !ok { + continue + } + if t, _ := pm["type"].(string); t != "text" { + continue + } + if text, ok := pm["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + return strings.TrimSpace(strings.Join(parts, "\n")) + default: + return "" + } +} + +func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[string]string) ([]any, error) { + arr, ok := messages.([]any) + if !ok { + return nil, errors.New("messages must be an array") + } + + out := make([]any, 0, len(arr)) + for _, m := range arr { + mm, ok := m.(map[string]any) + if !ok { + continue + } + role, _ := mm["role"].(string) + role = strings.ToLower(strings.TrimSpace(role)) + gRole := "user" + if role == "assistant" { + gRole = "model" + } + + parts := make([]any, 0) + switch content := mm["content"].(type) { + case string: + if strings.TrimSpace(content) != "" { + parts = append(parts, map[string]any{"text": content}) + } + case []any: + for _, block := range content { + bm, ok := block.(map[string]any) + if !ok { + continue + } + bt, _ := bm["type"].(string) + switch bt { + case "text": + if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, map[string]any{"text": text}) + } + case "tool_use": + id, _ := bm["id"].(string) + name, _ := bm["name"].(string) + if strings.TrimSpace(id) != "" && strings.TrimSpace(name) != "" { + toolUseIDToName[id] = name + } + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": name, + "args": bm["input"], + }, + }) + case "tool_result": + toolUseID, _ := bm["tool_use_id"].(string) + name := toolUseIDToName[toolUseID] + if name == "" { + name = "tool" + } + parts = append(parts, map[string]any{ + "functionResponse": map[string]any{ + "name": name, + "response": map[string]any{ + "content": extractClaudeContentText(bm["content"]), + }, + }, + }) + case "image": + if src, ok := bm["source"].(map[string]any); ok { + if srcType, _ := src["type"].(string); srcType == "base64" { + mediaType, _ := src["media_type"].(string) + data, _ := src["data"].(string) + if mediaType != "" && data != "" { + parts = append(parts, map[string]any{ + "inlineData": map[string]any{ + "mimeType": mediaType, + "data": data, + }, + }) + } + } + } + default: + // best-effort: preserve unknown blocks as text + if b, err := json.Marshal(bm); err == nil { + parts = append(parts, map[string]any{"text": string(b)}) + } + } + } + default: + // ignore + } + + out = append(out, map[string]any{ + "role": gRole, + "parts": parts, + }) + } + return out, nil +} + +func extractClaudeContentText(v any) string { + switch t := v.(type) { + case string: + return t + case []any: + var sb strings.Builder + for _, part := range t { + pm, ok := part.(map[string]any) + if !ok { + continue + } + if pm["type"] == "text" { + if text, ok := pm["text"].(string); ok { + sb.WriteString(text) + } + } + } + return sb.String() + default: + b, _ := json.Marshal(t) + return string(b) + } +} + +func convertClaudeToolsToGeminiTools(tools any) []any { + arr, ok := tools.([]any) + if !ok || len(arr) == 0 { + return nil + } + + funcDecls := make([]any, 0, len(arr)) + for _, t := range arr { + tm, ok := t.(map[string]any) + if !ok { + continue + } + name, _ := tm["name"].(string) + desc, _ := tm["description"].(string) + params := tm["input_schema"] + if name == "" { + continue + } + funcDecls = append(funcDecls, map[string]any{ + "name": name, + "description": desc, + "parameters": params, + }) + } + + if len(funcDecls) == 0 { + return nil + } + return []any{ + map[string]any{ + "functionDeclarations": funcDecls, + }, + } +} + +func convertClaudeGenerationConfig(req map[string]any) map[string]any { + out := make(map[string]any) + if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 { + out["maxOutputTokens"] = mt + } + if temp, ok := req["temperature"].(float64); ok { + out["temperature"] = temp + } + if topP, ok := req["top_p"].(float64); ok { + out["topP"] = topP + } + if stopSeq, ok := req["stop_sequences"].([]any); ok && len(stopSeq) > 0 { + out["stopSequences"] = stopSeq + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go new file mode 100644 index 00000000..067a2455 --- /dev/null +++ b/backend/internal/service/gemini_oauth_service.go @@ -0,0 +1,305 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/service/ports" +) + +type GeminiOAuthService struct { + sessionStore *geminicli.SessionStore + proxyRepo ports.ProxyRepository + oauthClient ports.GeminiOAuthClient + codeAssist ports.GeminiCliCodeAssistClient + cfg *config.Config +} + +func NewGeminiOAuthService( + proxyRepo ports.ProxyRepository, + oauthClient ports.GeminiOAuthClient, + codeAssist ports.GeminiCliCodeAssistClient, + cfg *config.Config, +) *GeminiOAuthService { + return &GeminiOAuthService{ + sessionStore: geminicli.NewSessionStore(), + proxyRepo: proxyRepo, + oauthClient: oauthClient, + codeAssist: codeAssist, + cfg: cfg, + } +} + +type GeminiAuthURLResult struct { + AuthURL string `json:"auth_url"` + SessionID string `json:"session_id"` + State string `json:"state"` +} + +func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*GeminiAuthURLResult, error) { + state, err := geminicli.GenerateState() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + codeVerifier, err := geminicli.GenerateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to generate code verifier: %w", err) + } + codeChallenge := geminicli.GenerateCodeChallenge(codeVerifier) + sessionID, err := geminicli.GenerateSessionID() + if err != nil { + return nil, fmt.Errorf("failed to generate session ID: %w", err) + } + + var proxyURL string + if proxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + session := &geminicli.OAuthSession{ + State: state, + CodeVerifier: codeVerifier, + ProxyURL: proxyURL, + RedirectURI: redirectURI, + CreatedAt: time.Now(), + } + s.sessionStore.Set(sessionID, session) + + oauthCfg := geminicli.OAuthConfig{ + ClientID: s.cfg.Gemini.OAuth.ClientID, + ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, + Scopes: s.cfg.Gemini.OAuth.Scopes, + } + + authURL, err := geminicli.BuildAuthorizationURL(oauthCfg, state, codeChallenge, redirectURI) + if err != nil { + return nil, err + } + + return &GeminiAuthURLResult{ + AuthURL: authURL, + SessionID: sessionID, + State: state, + }, nil +} + +type GeminiExchangeCodeInput struct { + SessionID string + State string + Code string + RedirectURI string + ProxyID *int64 +} + +type GeminiTokenInfo struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` + ProjectID string `json:"project_id,omitempty"` +} + +func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { + session, ok := s.sessionStore.Get(input.SessionID) + if !ok { + return nil, fmt.Errorf("session not found or expired") + } + if strings.TrimSpace(input.State) == "" || input.State != session.State { + return nil, fmt.Errorf("invalid state") + } + + proxyURL := session.ProxyURL + if input.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + redirectURI := session.RedirectURI + if strings.TrimSpace(input.RedirectURI) != "" { + redirectURI = input.RedirectURI + } + + tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) + if err != nil { + return nil, fmt.Errorf("failed to exchange code: %w", err) + } + s.sessionStore.Delete(input.SessionID) + + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn + projectID, _ := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + + return &GeminiTokenInfo{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + TokenType: tokenResp.TokenType, + ExpiresIn: tokenResp.ExpiresIn, + ExpiresAt: expiresAt, + Scope: tokenResp.Scope, + ProjectID: projectID, + }, nil +} + +func (s *GeminiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*GeminiTokenInfo, error) { + var lastErr error + + for attempt := 0; attempt <= 3; attempt++ { + if attempt > 0 { + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + time.Sleep(backoff) + } + + tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + if err == nil { + expiresAt := time.Now().Unix() + tokenResp.ExpiresIn + return &GeminiTokenInfo{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + TokenType: tokenResp.TokenType, + ExpiresIn: tokenResp.ExpiresIn, + ExpiresAt: expiresAt, + Scope: tokenResp.Scope, + }, nil + } + + if isNonRetryableGeminiOAuthError(err) { + return nil, err + } + lastErr = err + } + + return nil, fmt.Errorf("token refresh failed after retries: %w", lastErr) +} + +func isNonRetryableGeminiOAuthError(err error) bool { + msg := err.Error() + nonRetryable := []string{ + "invalid_grant", + "invalid_client", + "unauthorized_client", + "access_denied", + } + for _, needle := range nonRetryable { + if strings.Contains(msg, needle) { + return true + } + } + return false +} + +func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*GeminiTokenInfo, error) { + if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { + return nil, fmt.Errorf("account is not a Gemini OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") + if strings.TrimSpace(refreshToken) == "" { + return nil, fmt.Errorf("no refresh token available") + } + + var proxyURL string + if account.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) + if err == nil && proxy != nil { + proxyURL = proxy.URL() + } + } + + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo) map[string]any { + creds := map[string]any{ + "access_token": tokenInfo.AccessToken, + "expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10), + } + if tokenInfo.RefreshToken != "" { + creds["refresh_token"] = tokenInfo.RefreshToken + } + if tokenInfo.TokenType != "" { + creds["token_type"] = tokenInfo.TokenType + } + if tokenInfo.Scope != "" { + creds["scope"] = tokenInfo.Scope + } + if tokenInfo.ProjectID != "" { + creds["project_id"] = tokenInfo.ProjectID + } + return creds +} + +func (s *GeminiOAuthService) Stop() { + s.sessionStore.Stop() +} + +func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) { + if s.codeAssist == nil { + return "", errors.New("code assist client not configured") + } + + loadResp, err := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil) + if err == nil && strings.TrimSpace(loadResp.CurrentTier) != "" && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" { + return strings.TrimSpace(loadResp.CloudAICompanionProject), nil + } + + // pick default tier from allowedTiers, fallback to LEGACY. + tierID := "LEGACY" + if loadResp != nil { + for _, tier := range loadResp.AllowedTiers { + if tier.IsDefault && strings.TrimSpace(tier.ID) != "" { + tierID = tier.ID + break + } + } + } + + req := &geminicli.OnboardUserRequest{ + TierID: tierID, + Metadata: geminicli.LoadCodeAssistMetadata{ + IDEType: "ANTIGRAVITY", + Platform: "PLATFORM_UNSPECIFIED", + PluginType: "GEMINI", + }, + } + + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := s.codeAssist.OnboardUser(ctx, accessToken, proxyURL, req) + if err != nil { + return "", err + } + if resp.Done { + if resp.Response == nil || resp.Response.CloudAICompanionProject == nil { + return "", errors.New("onboardUser completed but no project_id returned") + } + switch v := resp.Response.CloudAICompanionProject.(type) { + case string: + return strings.TrimSpace(v), nil + case map[string]any: + if id, ok := v["id"].(string); ok { + return strings.TrimSpace(id), nil + } + } + return "", errors.New("onboardUser returned unsupported project_id format") + } + time.Sleep(2 * time.Second) + } + + return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts) +} diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go new file mode 100644 index 00000000..51a2f54a --- /dev/null +++ b/backend/internal/service/gemini_token_provider.go @@ -0,0 +1,139 @@ +package service + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service/ports" +) + +const ( + geminiTokenRefreshSkew = 3 * time.Minute + geminiTokenCacheSkew = 5 * time.Minute +) + +type GeminiTokenProvider struct { + accountRepo ports.AccountRepository + tokenCache ports.GeminiTokenCache + geminiOAuthService *GeminiOAuthService +} + +func NewGeminiTokenProvider( + accountRepo ports.AccountRepository, + tokenCache ports.GeminiTokenCache, + geminiOAuthService *GeminiOAuthService, +) *GeminiTokenProvider { + return &GeminiTokenProvider{ + accountRepo: accountRepo, + tokenCache: tokenCache, + geminiOAuthService: geminiOAuthService, + } +} + +func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *model.Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + if account.Platform != model.PlatformGemini || account.Type != model.AccountTypeOAuth { + return "", errors.New("not a gemini oauth account") + } + + cacheKey := geminiTokenCacheKey(account) + + // 1) Try cache first. + if p.tokenCache != nil { + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + } + + // 2) Refresh if needed (pre-expiry skew). + expiresAt := parseExpiresAt(account) + needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew + if needsRefresh && p.tokenCache != nil { + locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) + if err == nil && locked { + defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() + + // Re-check after lock (another worker may have refreshed). + if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + return token, nil + } + + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } + expiresAt = parseExpiresAt(account) + if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew { + if p.geminiOAuthService == nil { + return "", errors.New("gemini oauth service not configured") + } + tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + return "", err + } + newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + account.Credentials = model.JSONB(newCredentials) + _ = p.accountRepo.Update(ctx, account) + expiresAt = parseExpiresAt(account) + } + } + } + + accessToken := account.GetCredential("access_token") + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("access_token not found in credentials") + } + + // 3) Populate cache with TTL. + if p.tokenCache != nil { + ttl := 30 * time.Minute + if expiresAt != nil { + until := time.Until(*expiresAt) + switch { + case until > geminiTokenCacheSkew: + ttl = until - geminiTokenCacheSkew + case until > 0: + ttl = until + default: + ttl = time.Minute + } + } + _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl) + } + + return accessToken, nil +} + +func geminiTokenCacheKey(account *model.Account) string { + projectID := strings.TrimSpace(account.GetCredential("project_id")) + if projectID != "" { + return projectID + } + return "account:" + strconv.FormatInt(account.ID, 10) +} + +func parseExpiresAt(account *model.Account) *time.Time { + raw := strings.TrimSpace(account.GetCredential("expires_at")) + if raw == "" { + return nil + } + if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 { + t := time.Unix(unixSec, 0) + return &t + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + return &t + } + return nil +} diff --git a/backend/internal/service/gemini_token_refresher.go b/backend/internal/service/gemini_token_refresher.go new file mode 100644 index 00000000..25ad699d --- /dev/null +++ b/backend/internal/service/gemini_token_refresher.go @@ -0,0 +1,53 @@ +package service + +import ( + "context" + "strconv" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" +) + +type GeminiTokenRefresher struct { + geminiOAuthService *GeminiOAuthService +} + +func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiTokenRefresher { + return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService} +} + +func (r *GeminiTokenRefresher) CanRefresh(account *model.Account) bool { + return account.Platform == model.PlatformGemini && account.Type == model.AccountTypeOAuth +} + +func (r *GeminiTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool { + if !r.CanRefresh(account) { + return false + } + expiresAtStr := account.GetCredential("expires_at") + if expiresAtStr == "" { + return false + } + expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64) + if err != nil { + return false + } + expiryTime := time.Unix(expiresAt, 0) + return time.Until(expiryTime) < refreshWindow +} + +func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) { + tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account) + if err != nil { + return nil, err + } + + newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo) + for k, v := range account.Credentials { + if _, exists := newCredentials[k]; !exists { + newCredentials[k] = v + } + } + + return newCredentials, nil +}