feat(backend): 实现 Gemini AI Studio OAuth 和消息兼容服务
- gemini_oauth_service.go: 新增 AI Studio OAuth 类型支持 - gemini_token_provider.go: Token 提供器增强 - gemini_messages_compat_service.go: 支持 AI Studio 端点 - account_test_service.go: Gemini 账户可用性检测 - gateway_service.go: 网关服务适配 - openai_gateway_service.go: OpenAI 兼容层调整
This commit is contained in:
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -38,19 +40,30 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
httpUpstream HTTPUpstream
|
||||
accountRepo AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
httpUpstream HTTPUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(accountRepo AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream HTTPUpstream) *AccountTestService {
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
httpUpstream HTTPUpstream,
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
httpUpstream: httpUpstream,
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,6 +136,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.IsGemini() {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
@@ -368,6 +385,247 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = geminicli.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test payload (Gemini format)
|
||||
payload := createGeminiTestPayload()
|
||||
|
||||
// Build request based on account type
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case model.AccountTypeApiKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case model.AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error()))
|
||||
}
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
// Get proxy and execute request
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processGeminiStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
|
||||
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil, fmt.Errorf("No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for real-time feedback
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
|
||||
strings.TrimRight(baseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts
|
||||
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *model.Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, fmt.Errorf("Gemini token provider not configured")
|
||||
}
|
||||
|
||||
// Get access token (auto-refreshes if needed)
|
||||
accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
// AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token.
|
||||
baseURL := account.GetCredential("base_url")
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(baseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Wrap payload in Code Assist format
|
||||
var inner map[string]any
|
||||
if err := json.Unmarshal(payload, &inner); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrapped := map[string]any{
|
||||
"model": modelID,
|
||||
"project": projectID,
|
||||
"request": inner,
|
||||
}
|
||||
wrappedBytes, _ := json.Marshal(wrapped)
|
||||
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", geminicli.GeminiCliBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
||||
func createGeminiTestPayload() []byte {
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"systemInstruction": map[string]any{
|
||||
"parts": []map[string]any{
|
||||
{"text": "You are a helpful AI assistant."},
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(payload)
|
||||
return bytes
|
||||
}
|
||||
|
||||
// processGeminiStream processes SSE stream from Gemini API
|
||||
func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract text from candidates[0].content.parts[].text
|
||||
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if candidate, ok := candidates[0].(map[string]any); ok {
|
||||
// Check for completion
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract content
|
||||
if content, ok := candidate["content"].(map[string]any); ok {
|
||||
if parts, ok := content["parts"].([]any); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle errors
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
errorMsg := "Unknown error"
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
payload := map[string]any{
|
||||
|
||||
Reference in New Issue
Block a user