feat: add support for Codex forwarding
@@ -14,7 +14,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -22,7 +24,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
)
|
||||
|
||||
// TestEvent represents an SSE event for account testing
|
||||
@@ -36,17 +40,19 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
claudeUpstream ClaudeUpstream
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, openaiOAuthService *OpenAIOAuthService, httpUpstream ports.HTTPUpstream) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +120,18 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
// Route to platform-specific test method
|
||||
if account.IsOpenAI() {
|
||||
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testClaudeAccountConnection tests an Anthropic Claude account's connection
|
||||
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
@@ -222,7 +240,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.claudeUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -234,11 +252,153 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processStream(c, resp.Body)
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// processStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = openai.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
var apiURL string
|
||||
var isOAuth bool
|
||||
var chatgptAccountID string
|
||||
|
||||
if account.IsOAuth() {
|
||||
isOAuth = true
|
||||
// OAuth - use Bearer token with ChatGPT internal API
|
||||
authToken = account.GetOpenAIAccessToken()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Check if token is expired and refresh if needed
|
||||
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
|
||||
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||
}
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
|
||||
// OAuth uses ChatGPT internal API
|
||||
apiURL = chatgptCodexAPIURL
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use Platform API
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
apiURL = strings.TrimSuffix(baseURL, "/") + "/v1/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create OpenAI Responses API payload
|
||||
payload := createOpenAITestPayload(testModelID, isOAuth)
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
|
||||
// Set OAuth-specific headers for ChatGPT internal API
|
||||
if isOAuth {
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
payload := map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "hi",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
// OAuth accounts using ChatGPT internal API require store: false and instructions
|
||||
if isOAuth {
|
||||
payload["store"] = false
|
||||
payload["instructions"] = openai.DefaultInstructions
|
||||
}
|
||||
|
||||
return payload
|
||||
}
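// Editor's sketch (not part of the commit): serializing the helper above to inspect the
// resulting Responses API body. The model name passed in is illustrative; the real test
// path uses openai.DefaultTestModel. In OAuth mode the payload additionally carries
// store=false and the openai.DefaultInstructions text.
func sketchOpenAITestPayloadJSON() string {
	payload := createOpenAITestPayload("gpt-4o-mini", true) // model name is a placeholder
	b, _ := json.MarshalIndent(payload, "", "  ")
	return string(b)
}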
|
||||
|
||||
// processClaudeStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
@@ -291,6 +451,59 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
}
|
||||
}
|
||||
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := data["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "response.output_text.delta":
|
||||
// OpenAI Responses API uses "delta" field for text content
|
||||
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||
}
|
||||
case "response.completed":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
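// Editor's sketch (not part of the commit): the shape of SSE lines that
// processOpenAIStream above handles. The JSON bodies are illustrative and trimmed;
// real Responses API events carry more fields.
const exampleOpenAISSE = "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hel\"}\n" +
	"data: {\"type\":\"response.output_text.delta\",\"delta\":\"lo\"}\n" +
	"data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":8,\"output_tokens\":2}}}\n" +
	"data: [DONE]\n"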
|
||||
|
||||
// sendEvent sends an SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
|
||||
@@ -24,11 +24,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ClaudeUpstream handles HTTP requests to Claude API
|
||||
type ClaudeUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
|
||||
|
||||
const (
|
||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
@@ -87,7 +82,7 @@ type GatewayService struct {
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
identityService *IdentityService
|
||||
claudeUpstream ClaudeUpstream
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -102,7 +97,7 @@ func NewGatewayService(
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
claudeUpstream ClaudeUpstream,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
) *GatewayService {
|
||||
return &GatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -115,7 +110,7 @@ func NewGatewayService(
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
identityService: identityService,
|
||||
claudeUpstream: claudeUpstream,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,13 +280,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Get the list of schedulable accounts (excluding rate-limited and overloaded accounts)
|
||||
// 2. Get the list of schedulable accounts (excluding rate-limited and overloaded accounts; Anthropic platform only)
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -407,7 +402,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}
|
||||
|
||||
// Send the request
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
@@ -481,7 +476,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// Set the authentication header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -502,8 +497,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
@@ -982,7 +977,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// Send the request
|
||||
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
@@ -1049,7 +1044,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
// Set the authentication header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", token)
|
||||
}
|
||||
@@ -1073,8 +1068,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
if req.Header.Get("anthropic-version") == "" {
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
@@ -114,12 +114,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *ports.Fingerpr
|
||||
return
|
||||
}
|
||||
|
||||
// Set User-Agent
|
||||
// Set user-agent
|
||||
if fp.UserAgent != "" {
|
||||
req.Header.Set("User-Agent", fp.UserAgent)
|
||||
req.Header.Set("user-agent", fp.UserAgent)
|
||||
}
|
||||
|
||||
// Set x-stainless-* headers (using the correct casing)
|
||||
// Set x-stainless-* headers
|
||||
if fp.StainlessLang != "" {
|
||||
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
|
||||
}
|
||||
|
||||
@@ -284,3 +284,8 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
700
backend/internal/service/openai_gateway_service.go
Normal file
@@ -0,0 +1,700 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChatGPT internal API for OAuth accounts
|
||||
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
// OpenAI Platform API for API Key accounts (fallback)
|
||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||
openaiStickySessionTTL = time.Hour // sticky session TTL
|
||||
)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
"content-type": true,
|
||||
"user-agent": true,
|
||||
"originator": true,
|
||||
"session_id": true,
|
||||
}
|
||||
|
||||
// OpenAIUsage represents OpenAI API response usage
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIForwardResult represents the result of forwarding
|
||||
type OpenAIForwardResult struct {
|
||||
RequestID string
|
||||
Usage OpenAIUsage
|
||||
Model string
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int
|
||||
}
|
||||
|
||||
// OpenAIGatewayService handles OpenAI API gateway operations
|
||||
type OpenAIGatewayService struct {
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
cache ports.GatewayCache
|
||||
cfg *config.Config
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
httpUpstream ports.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayService creates a new OpenAIGatewayService
|
||||
func NewOpenAIGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
cache ports.GatewayCache,
|
||||
cfg *config.Config,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
httpUpstream ports.HTTPUpstream,
|
||||
) *OpenAIGatewayService {
|
||||
return &OpenAIGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
sessionID := c.GetHeader("session_id")
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
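// Editor's sketch (not part of the commit): GenerateSessionHash above reduces the
// caller-supplied session_id header to a fixed-length sticky-session key. A standalone
// equivalent, using only crypto/sha256 and encoding/hex (already imported in this file):
func sessionHashOf(sessionID string) string {
	if sessionID == "" {
		return ""
	}
	sum := sha256.Sum256([]byte(sessionID))
	return hex.EncodeToString(sum[:])
}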
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
}
|
||||
|
||||
// SelectAccountForModel selects an account supporting the requested model
|
||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// Refresh sticky session TTL
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *model.Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
// Check model support
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
// Lower priority value means higher priority
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
// Same priority, select least recently used
|
||||
if acc.LastUsedAt == nil || (selected.LastUsedAt != nil && acc.LastUsedAt.Before(*selected.LastUsedAt)) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected == nil {
|
||||
if requestedModel != "" {
|
||||
return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
|
||||
}
|
||||
return nil, errors.New("no available OpenAI accounts")
|
||||
}
|
||||
|
||||
// 4. Set sticky session
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
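// Editor's sketch (not part of the commit): the selection rule above in isolation —
// a lower Priority value wins, ties go to the least recently used account, and a nil
// LastUsedAt counts as "never used" and therefore wins the tie. Call as
// selected = preferAccount(selected, candidate).
func preferAccount(a, b *model.Account) *model.Account {
	if b == nil {
		return a
	}
	if a == nil || b.Priority < a.Priority {
		return b
	}
	if b.Priority == a.Priority &&
		(b.LastUsedAt == nil || (a.LastUsedAt != nil && b.LastUsedAt.Before(*a.LastUsedAt))) {
		return b
	}
	return a
}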
|
||||
|
||||
// GetAccessToken gets the access token for an OpenAI account
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
} else if account.Type == model.AccountTypeApiKey {
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
}
|
||||
return apiKey, "apikey", nil
|
||||
}
|
||||
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
// Apply model mapping
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API, add store: false
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
if bodyModified {
|
||||
var err error
|
||||
body, err = json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("serialize request body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get access token
|
||||
token, _, err := s.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode >= 400 {
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// Handle normal response
|
||||
var usage *OpenAIUsage
|
||||
var firstTokenMs *int
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &OpenAIForwardResult{
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
} else if account.Type == model.AccountTypeApiKey {
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
targetURL = baseURL + "/v1/responses"
|
||||
} else {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
}
|
||||
} else {
|
||||
targetURL = openaiPlatformAPIURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set authentication header
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
// Set headers specific to OAuth accounts (ChatGPT internal API)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
|
||||
req.Host = "chatgpt.com"
|
||||
// Required: set chatgpt-account-id header
|
||||
chatgptAccountID := account.GetChatGPTAccountID()
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
// Set accept header based on stream mode
|
||||
if isStream {
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("accept", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// Whitelist passthrough headers
|
||||
for key, values := range c.Request.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if openaiAllowedHeaders[lowerKey] {
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom User-Agent if configured
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
if customUA != "" {
|
||||
req.Header.Set("user-agent", customUA)
|
||||
}
|
||||
|
||||
// Ensure required headers exist
|
||||
if req.Header.Get("content-type") == "" {
|
||||
req.Header.Set("content-type", "application/json")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// Check custom error codes
|
||||
if !account.ShouldHandleErrorCode(resp.StatusCode) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream gateway error",
|
||||
},
|
||||
})
|
||||
return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
|
||||
// Return appropriate error response
|
||||
var errType, errMsg string
|
||||
var statusCode int
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 401:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream access forbidden, please contact administrator"
|
||||
case 429:
|
||||
statusCode = http.StatusTooManyRequests
|
||||
errType = "rate_limit_error"
|
||||
errMsg = "Upstream rate limit exceeded, please retry later"
|
||||
default:
|
||||
statusCode = http.StatusBadGateway
|
||||
errType = "upstream_error"
|
||||
errMsg = "Upstream request failed"
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
"message": errMsg,
|
||||
},
|
||||
})
|
||||
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// openaiStreamingResult streaming response result
|
||||
type openaiStreamingResult struct {
|
||||
usage *OpenAIUsage
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
// Set SSE response headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// Pass through other headers
|
||||
if v := resp.Header.Get("x-request-id"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{}
|
||||
var firstTokenMs *int
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
|
||||
|
||||
needModelReplace := originalModel != mappedModel
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace && strings.HasPrefix(line, "data: ") {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Parse usage data
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := line[6:]
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
|
||||
}
|
||||
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
data := line[6:]
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
|
||||
var event map[string]any
|
||||
if err := json.Unmarshal([]byte(data), &event); err != nil {
|
||||
return line
|
||||
}
|
||||
|
||||
// Replace model in response
|
||||
if m, ok := event["model"].(string); ok && m == fromModel {
|
||||
event["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
|
||||
// Check nested response
|
||||
if response, ok := event["response"].(map[string]any); ok {
|
||||
if m, ok := response["model"].(string); ok && m == fromModel {
|
||||
response["model"] = toModel
|
||||
newData, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return line
|
||||
}
|
||||
return "data: " + string(newData)
|
||||
}
|
||||
}
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
||||
usage.InputTokens = event.Response.Usage.InputTokens
|
||||
usage.OutputTokens = event.Response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
||||
}
|
||||
}
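// Editor's sketch (not part of the commit): feeding a single response.completed payload
// through parseSSEUsage above. The JSON is illustrative and trimmed to the fields the
// parser reads.
func exampleParseUsage(s *OpenAIGatewayService) OpenAIUsage {
	var u OpenAIUsage
	s.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":8,"output_tokens":2,"input_tokens_details":{"cached_tokens":4}}}}`, &u)
	return u // InputTokens=8, OutputTokens=2, CacheReadInputTokens=4
}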
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse usage
|
||||
var response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
usage := &OpenAIUsage{
|
||||
InputTokens: response.Usage.InputTokens,
|
||||
OutputTokens: response.Usage.OutputTokens,
|
||||
CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
|
||||
}
|
||||
|
||||
// Replace model in response if needed
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Pass through headers
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
c.Header(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(resp.StatusCode, "application/json", body)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
model, ok := resp["model"].(string)
|
||||
if !ok || model != fromModel {
|
||||
return body
|
||||
}
|
||||
|
||||
resp["model"] = toModel
|
||||
newBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
return newBody
|
||||
}
|
||||
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *model.ApiKey
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
Subscription *model.UserSubscription
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// Calculate cost
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
|
||||
// Get rate multiplier
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
|
||||
// Determine billing type
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := model.BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = model.BillingTypeSubscription
|
||||
}
|
||||
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &model.UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
InputCost: cost.InputCost,
|
||||
OutputCost: cost.OutputCost,
|
||||
CacheCreationCost: cost.CacheCreationCost,
|
||||
CacheReadCost: cost.CacheReadCost,
|
||||
TotalCost: cost.TotalCost,
|
||||
ActualCost: cost.ActualCost,
|
||||
RateMultiplier: multiplier,
|
||||
BillingType: billingType,
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if apiKey.GroupID != nil {
|
||||
usageLog.GroupID = apiKey.GroupID
|
||||
}
|
||||
if subscription != nil {
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Update account last used
|
||||
_ = s.accountRepo.UpdateLastUsed(ctx, account.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
257
backend/internal/service/openai_oauth_service.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/openai"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
proxyRepo ports.ProxyRepository
|
||||
oauthClient ports.OpenAIOAuthClient
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
||||
func NewOpenAIOAuthService(proxyRepo ports.ProxyRepository, oauthClient ports.OpenAIOAuthClient) *OpenAIOAuthService {
|
||||
return &OpenAIOAuthService{
|
||||
sessionStore: openai.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIAuthURLResult contains the authorization URL and session info
|
||||
type OpenAIAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate session ID
|
||||
sessionID, err := openai.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use default redirect URI if not specified
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
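// Editor's sketch (not part of the commit): the PKCE S256 relationship the openai helpers
// above are assumed to implement — the code challenge placed in the authorization URL is
// the unpadded base64url SHA-256 of the code verifier kept in the session. Requires
// crypto/sha256 and encoding/base64, which this file does not currently import.
func s256Challenge(codeVerifier string) string {
	sum := sha256.Sum256([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(sum[:])
}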
|
||||
|
||||
// OpenAIExchangeCodeInput represents the input for code exchange
|
||||
type OpenAIExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// OpenAITokenInfo represents the token information for OpenAI
|
||||
type OpenAITokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
OrganizationID string `json:"organization_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use redirect URI from session or input
|
||||
redirectURI := session.RedirectURI
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
// Delete session after successful exchange
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials builds credentials map from token info
|
||||
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
creds["id_token"] = tokenInfo.IDToken
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ChatGPTAccountID != "" {
|
||||
creds["chatgpt_account_id"] = tokenInfo.ChatGPTAccountID
|
||||
}
|
||||
if tokenInfo.ChatGPTUserID != "" {
|
||||
creds["chatgpt_user_id"] = tokenInfo.ChatGPTUserID
|
||||
}
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OpenAIOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
@@ -27,6 +27,8 @@ type AccountRepository interface {
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
|
||||
9
backend/internal/service/ports/http_upstream.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package ports
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPUpstream interface for making HTTP requests to upstream APIs (Claude, OpenAI, etc.)
|
||||
// This is a generic interface that can be used for any HTTP-based upstream service.
|
||||
type HTTPUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
}
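// Editor's sketch (not part of the commit): a minimal implementation of this interface
// using only net/http and net/url. The project's real upstream client likely adds
// timeouts, per-proxy connection pooling, and TLS configuration.
type simpleHTTPUpstream struct{}

func (simpleHTTPUpstream) Do(req *http.Request, proxyURL string) (*http.Response, error) {
	transport := &http.Transport{}
	if proxyURL != "" {
		u, err := url.Parse(proxyURL)
		if err != nil {
			return nil, err
		}
		transport.Proxy = http.ProxyURL(u)
	}
	client := &http.Client{Transport: transport}
	return client.Do(req)
}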
|
||||
13
backend/internal/service/ports/openai_oauth.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// OpenAIOAuthClient interface for OpenAI OAuth operations
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
}
|
||||
@@ -2,30 +2,32 @@ package service
|
||||
|
||||
// Services is the container holding all service instances
|
||||
type Services struct {
|
||||
Auth *AuthService
|
||||
User *UserService
|
||||
ApiKey *ApiKeyService
|
||||
Group *GroupService
|
||||
Account *AccountService
|
||||
Proxy *ProxyService
|
||||
Redeem *RedeemService
|
||||
Usage *UsageService
|
||||
Pricing *PricingService
|
||||
Billing *BillingService
|
||||
BillingCache *BillingCacheService
|
||||
Admin AdminService
|
||||
Gateway *GatewayService
|
||||
OAuth *OAuthService
|
||||
RateLimit *RateLimitService
|
||||
AccountUsage *AccountUsageService
|
||||
AccountTest *AccountTestService
|
||||
Setting *SettingService
|
||||
Email *EmailService
|
||||
EmailQueue *EmailQueueService
|
||||
Turnstile *TurnstileService
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
Auth *AuthService
|
||||
User *UserService
|
||||
ApiKey *ApiKeyService
|
||||
Group *GroupService
|
||||
Account *AccountService
|
||||
Proxy *ProxyService
|
||||
Redeem *RedeemService
|
||||
Usage *UsageService
|
||||
Pricing *PricingService
|
||||
Billing *BillingService
|
||||
BillingCache *BillingCacheService
|
||||
Admin AdminService
|
||||
Gateway *GatewayService
|
||||
OpenAIGateway *OpenAIGatewayService
|
||||
OAuth *OAuthService
|
||||
OpenAIOAuth *OpenAIOAuthService
|
||||
RateLimit *RateLimitService
|
||||
AccountUsage *AccountUsageService
|
||||
AccountTest *AccountTestService
|
||||
Setting *SettingService
|
||||
Email *EmailService
|
||||
EmailQueue *EmailQueueService
|
||||
Turnstile *TurnstileService
|
||||
Subscription *SubscriptionService
|
||||
Concurrency *ConcurrencyService
|
||||
Identity *IdentityService
|
||||
Update *UpdateService
|
||||
TokenRefresh *TokenRefreshService
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type TokenRefreshService struct {
|
||||
func NewTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
s := &TokenRefreshService{
|
||||
@@ -38,9 +39,7 @@ func NewTokenRefreshService(
|
||||
// Register platform-specific refreshers
|
||||
s.refreshers = []TokenRefresher{
|
||||
NewClaudeTokenRefresher(oauthService),
|
||||
// Refreshers for other platforms can be added in the future:
|
||||
// NewOpenAITokenRefresher(...),
|
||||
// NewGeminiTokenRefresher(...),
|
||||
NewOpenAITokenRefresher(openaiOAuthService),
|
||||
}
|
||||
|
||||
return s
|
||||
|
||||
@@ -88,3 +88,54 @@ func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Accou
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
// OpenAITokenRefresher handles OpenAI OAuth token refresh
|
||||
type OpenAITokenRefresher struct {
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
}
|
||||
|
||||
// NewOpenAITokenRefresher creates an OpenAI token refresher
|
||||
func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAITokenRefresher {
|
||||
return &OpenAITokenRefresher{
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh reports whether this refresher can handle the given account
|
||||
// Only oauth-type accounts on the openai platform are handled
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformOpenAI &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh reports whether the token needs to be refreshed
|
||||
// Based on the expires_at field, checks whether the token falls within the refresh window
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
expiresAt := account.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Until(*expiresAt) < refreshWindow
|
||||
}
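// Editor's sketch (not part of the commit): the refresh-window check above with a concrete
// value. With refreshWindow = 15*time.Minute, a token expiring in 10 minutes is refreshed,
// one expiring in an hour is left alone, and an account with no expires_at recorded is
// never refreshed proactively.
func needsRefreshIn(expiresAt *time.Time, refreshWindow time.Duration) bool {
	if expiresAt == nil {
		return false
	}
	return time.Until(*expiresAt) < refreshWindow
}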
|
||||
|
||||
// Refresh performs the token refresh
|
||||
// Preserves all fields in the existing credentials and only updates the token-related fields
|
||||
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build new credentials via the service-provided helper, preserving existing fields
|
||||
newCredentials := r.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
// Carry over non-token fields from the existing credentials
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
@@ -37,9 +37,10 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||
func ProvideTokenRefreshService(
|
||||
accountRepo ports.AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
cfg *config.Config,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, cfg)
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@@ -60,7 +61,9 @@ var ProviderSet = wire.NewSet(
|
||||
NewBillingCacheService,
|
||||
NewAdminService,
|
||||
NewGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewRateLimitService,
|
||||
NewAccountUsageService,
|
||||
NewAccountTestService,
|
||||
|
||||