// Source: sub2api/backend/internal/service/openai_gateway_service.go
// Snapshot: 2025-12-22 22:58:31 +08:00 · 701 lines · 20 KiB · Go

package service
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
)
const (
// chatgptCodexURL is the ChatGPT internal API endpoint; it is the upstream
// target for OAuth accounts.
chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
// openaiPlatformAPIURL is the OpenAI Platform API endpoint; it is the upstream
// target for API Key accounts without a custom base URL (and the fallback for
// any other account type).
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // sticky session TTL
)
// openaiAllowedHeaders is the whitelist of client request headers that are
// copied onto the upstream request. Keys are lower-case; lookups are done with
// the lower-cased incoming header name. The whitelist is applied to every
// account type when the upstream request is built.
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
"content-type": true,
"user-agent": true,
"originator": true,
"session_id": true,
}
// OpenAIUsage holds the token counts extracted from an OpenAI API response.
type OpenAIUsage struct {
// InputTokens is taken from usage.input_tokens of the upstream response.
InputTokens int `json:"input_tokens"`
// OutputTokens is taken from usage.output_tokens of the upstream response.
OutputTokens int `json:"output_tokens"`
// CacheCreationInputTokens is forwarded to billing as cache-creation tokens.
CacheCreationInputTokens int `json:"cache_creation_input_tokens,omitempty"`
// CacheReadInputTokens is filled from usage.input_tokens_details.cached_tokens.
CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"`
}
// OpenAIForwardResult describes one successfully forwarded request.
type OpenAIForwardResult struct {
// RequestID is the upstream "x-request-id" response header.
RequestID string
// Usage is the token usage parsed from the upstream response.
Usage OpenAIUsage
// Model is the model name the client asked for (before account model mapping).
Model string
// Stream reports whether the request body had "stream": true.
Stream bool
// Duration is the wall-clock time from the start of Forward to its return.
Duration time.Duration
// FirstTokenMs is the time in milliseconds until the first non-empty SSE data
// line; it is nil for non-streaming requests.
FirstTokenMs *int
}
// OpenAIGatewayService handles OpenAI API gateway operations: account
// selection, request forwarding, and usage recording.
type OpenAIGatewayService struct {
// accountRepo loads and lists accounts and records their last-used time.
accountRepo ports.AccountRepository
// usageLogRepo persists usage log entries.
usageLogRepo ports.UsageLogRepository
// userRepo is used to deduct user balance.
userRepo ports.UserRepository
// userSubRepo is used to increment subscription usage.
userSubRepo ports.UserSubscriptionRepository
// cache stores the sticky session -> account ID binding.
cache ports.GatewayCache
// cfg supplies the default rate multiplier.
cfg *config.Config
// billingService computes the cost of a request.
billingService *BillingService
// rateLimitService is notified of upstream error responses.
rateLimitService *RateLimitService
// billingCacheService mirrors balance/subscription changes into the cache.
billingCacheService *BillingCacheService
// httpUpstream sends the upstream HTTP request (optionally through a proxy URL).
httpUpstream ports.HTTPUpstream
}
// NewOpenAIGatewayService wires the given dependencies into a new
// OpenAIGatewayService. No dependency is validated or copied; each argument is
// stored as-is in the field of the same name.
func NewOpenAIGatewayService(
	accountRepo ports.AccountRepository,
	usageLogRepo ports.UsageLogRepository,
	userRepo ports.UserRepository,
	userSubRepo ports.UserSubscriptionRepository,
	cache ports.GatewayCache,
	cfg *config.Config,
	billingService *BillingService,
	rateLimitService *RateLimitService,
	billingCacheService *BillingCacheService,
	httpUpstream ports.HTTPUpstream,
) *OpenAIGatewayService {
	svc := new(OpenAIGatewayService)

	// Persistence layer.
	svc.accountRepo = accountRepo
	svc.usageLogRepo = usageLogRepo
	svc.userRepo = userRepo
	svc.userSubRepo = userSubRepo

	// Cache and configuration.
	svc.cache = cache
	svc.cfg = cfg

	// Collaborating services.
	svc.billingService = billingService
	svc.rateLimitService = rateLimitService
	svc.billingCacheService = billingCacheService

	// Outbound HTTP.
	svc.httpUpstream = httpUpstream

	return svc
}
// GenerateSessionHash derives the sticky-session key for a request. It reads
// the "session_id" request header and returns its SHA-256 digest as a
// lower-case hex string, or "" when the header is absent or empty.
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
	raw := c.GetHeader("session_id")
	if len(raw) == 0 {
		return ""
	}

	hasher := sha256.New()
	// hash.Hash.Write never returns an error.
	_, _ = hasher.Write([]byte(raw))
	return hex.EncodeToString(hasher.Sum(nil))
}
// SelectAccount picks an OpenAI account with sticky-session support and no
// model constraint. It is shorthand for SelectAccountForModel with an empty
// requested model.
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
	// An empty model name disables the model-support filter.
	const anyModel = ""
	return s.SelectAccountForModel(ctx, groupID, sessionHash, anyModel)
}
// SelectAccountForModel picks an OpenAI account able to serve requestedModel.
//
// Selection order:
//  1. If sessionHash is non-empty and the cache holds a binding for it, the
//     bound account is reused as long as it is still schedulable, is an OpenAI
//     account and supports the model; the binding's TTL is refreshed.
//  2. Otherwise the schedulable OpenAI accounts (of the group when groupID is
//     non-nil, else of the whole platform) are listed.
//  3. Among accounts supporting the model, the lowest Priority value wins; ties
//     go to the least recently used account (a nil LastUsedAt counts as oldest).
//  4. The winner is bound to sessionHash when one was supplied.
//
// An empty requestedModel disables the model filter. Cache write/refresh
// errors are ignored: sticky sessions are an optimisation only.
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
	stickyKey := "openai:" + sessionHash
	anyModel := requestedModel == ""

	// Step 1: try the account already bound to this session.
	if sessionHash != "" {
		if boundID, lookupErr := s.cache.GetSessionAccountID(ctx, stickyKey); lookupErr == nil && boundID > 0 {
			bound, loadErr := s.accountRepo.GetByID(ctx, boundID)
			if loadErr == nil && bound.IsSchedulable() && bound.IsOpenAI() && (anyModel || bound.IsModelSupported(requestedModel)) {
				_ = s.cache.RefreshSessionTTL(ctx, stickyKey, openaiStickySessionTTL)
				return bound, nil
			}
		}
	}

	// Step 2: load the candidate pool.
	var (
		candidates []model.Account
		listErr    error
	)
	if groupID == nil {
		candidates, listErr = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
	} else {
		candidates, listErr = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
	}
	if listErr != nil {
		return nil, fmt.Errorf("query accounts failed: %w", listErr)
	}

	// Step 3: priority first (smaller value is better), then least recently used.
	var best *model.Account
	for idx := range candidates {
		cand := &candidates[idx]
		if !anyModel && !cand.IsModelSupported(requestedModel) {
			continue
		}
		switch {
		case best == nil:
			best = cand
		case cand.Priority < best.Priority:
			best = cand
		case cand.Priority == best.Priority:
			if cand.LastUsedAt == nil || (best.LastUsedAt != nil && cand.LastUsedAt.Before(*best.LastUsedAt)) {
				best = cand
			}
		}
	}

	if best == nil {
		if anyModel {
			return nil, errors.New("no available OpenAI accounts")
		}
		return nil, fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
	}

	// Step 4: remember the choice for subsequent requests of this session.
	if sessionHash != "" {
		_ = s.cache.SetSessionAccountID(ctx, stickyKey, best.ID, openaiStickySessionTTL)
	}
	return best, nil
}
// GetAccessToken returns the credential used to authenticate against the
// upstream for the given account, together with its kind.
//
// The second return value is "oauth" for OAuth accounts (access token) and
// "apikey" for API Key accounts (API key). An error is returned when the
// credential is empty or the account type is neither of the two.
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
	switch account.Type {
	case model.AccountTypeOAuth:
		if tok := account.GetOpenAIAccessToken(); tok != "" {
			return tok, "oauth", nil
		}
		return "", "", errors.New("access_token not found in credentials")
	case model.AccountTypeApiKey:
		if key := account.GetOpenAIApiKey(); key != "" {
			return key, "apikey", nil
		}
		return "", "", errors.New("api_key not found in credentials")
	default:
		return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
	}
}
// Forward sends the client's request body to the upstream selected by account
// and relays the upstream answer to the client through c.
//
// The body is parsed once to read "model" and "stream". It is re-serialized
// only if something changed: the model was remapped by the account's model
// mapping, or (OAuth accounts) "store": false was injected. An upstream status
// >= 400 is delegated to handleErrorResponse, which writes the client response
// itself. On success the returned result carries the client-requested model
// name (not the mapped one), the parsed usage, and timing information.
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
	begin := time.Now()

	// Decode once; every later decision reads from this map.
	var payload map[string]any
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	originalModel, _ := payload["model"].(string)
	wantStream, _ := payload["stream"].(bool)

	// dirty tracks whether the payload diverged from the bytes we received.
	dirty := false

	mappedModel := account.GetMappedModel(originalModel)
	if mappedModel != originalModel {
		payload["model"] = mappedModel
		dirty = true
	}

	// OAuth accounts talk to the ChatGPT internal API, which gets store: false.
	if account.Type == model.AccountTypeOAuth {
		payload["store"] = false
		dirty = true
	}

	if dirty {
		encoded, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("serialize request body: %w", err)
		}
		body = encoded
	}

	token, _, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}

	upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, wantStream)
	if err != nil {
		return nil, err
	}

	// An empty proxy URL means "no proxy".
	var proxyURL string
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
	if err != nil {
		return nil, fmt.Errorf("upstream request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode >= 400 {
		return s.handleErrorResponse(ctx, resp, c, account)
	}

	result := &OpenAIForwardResult{
		Model:  originalModel,
		Stream: wantStream,
	}
	if wantStream {
		streamed, err := s.handleStreamingResponse(ctx, resp, c, account, begin, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		result.Usage = *streamed.usage
		result.FirstTokenMs = streamed.firstTokenMs
	} else {
		usage, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, mappedModel)
		if err != nil {
			return nil, err
		}
		result.Usage = *usage
	}

	result.RequestID = resp.Header.Get("x-request-id")
	result.Duration = time.Since(begin)
	return result, nil
}
// buildUpstreamRequest creates the POST request that is sent upstream.
//
// Target URL:
//   - OAuth accounts: the ChatGPT internal API (chatgptCodexURL).
//   - API Key accounts: "<base URL>/v1/responses" when the account has a custom
//     base URL, otherwise the OpenAI Platform API.
//   - Any other account type: the OpenAI Platform API.
//
// Trailing slashes on a custom base URL are trimmed so that a value such as
// "https://example.com/" does not produce "https://example.com//v1/responses".
//
// Headers: "authorization" is always set from token. OAuth accounts also get
// the Host override, the optional "chatgpt-account-id" and an "accept" header
// matching isStream. Client headers listed in openaiAllowedHeaders are copied
// over, the account's custom User-Agent (if any) replaces the client's, and
// "content-type" defaults to "application/json".
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
	isOAuth := account.Type == model.AccountTypeOAuth

	// Resolve the upstream endpoint; the Platform API is the default.
	targetURL := openaiPlatformAPIURL
	switch {
	case isOAuth:
		targetURL = chatgptCodexURL
	case account.Type == model.AccountTypeApiKey:
		if baseURL := strings.TrimRight(account.GetOpenAIBaseURL(), "/"); baseURL != "" {
			targetURL = baseURL + "/v1/responses"
		}
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("authorization", "Bearer "+token)

	if isOAuth {
		// The Host must be set on the request itself; net/http ignores a
		// "Host" entry in req.Header.
		req.Host = "chatgpt.com"
		if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
			req.Header.Set("chatgpt-account-id", chatgptAccountID)
		}
		accept := "application/json"
		if isStream {
			accept = "text/event-stream"
		}
		req.Header.Set("accept", accept)
	}

	// Copy only whitelisted client headers to the upstream request.
	for key, values := range c.Request.Header {
		if !openaiAllowedHeaders[strings.ToLower(key)] {
			continue
		}
		for _, v := range values {
			req.Header.Add(key, v)
		}
	}

	// An account-level User-Agent overrides whatever the client sent.
	if customUA := account.GetOpenAIUserAgent(); customUA != "" {
		req.Header.Set("user-agent", customUA)
	}

	if req.Header.Get("content-type") == "" {
		req.Header.Set("content-type", "application/json")
	}
	return req, nil
}
// handleErrorResponse turns an upstream error status into a client response
// and a Go error. It always returns a nil result and a non-nil error.
//
// If the account is configured not to handle this status code, the client
// gets a generic 500 and the account state is left untouched. Otherwise the
// rate-limit service is informed (so it can mark the account) and the client
// receives:
//   - 401 / 403 upstream: 502 "upstream_error" with a specific message
//   - 429 upstream:       429 "rate_limit_error"
//   - anything else:      502 "upstream_error"
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
	// Best effort: a read failure just leaves the body short or empty.
	upstreamBody, _ := io.ReadAll(resp.Body)
	upstreamStatus := resp.StatusCode

	if !account.ShouldHandleErrorCode(upstreamStatus) {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"type":    "upstream_error",
				"message": "Upstream gateway error",
			},
		})
		return nil, fmt.Errorf("upstream error: %d (not in custom error codes)", upstreamStatus)
	}

	// Let the rate-limit service update the account status.
	s.rateLimitService.HandleUpstreamError(ctx, account, upstreamStatus, resp.Header, upstreamBody)

	// Start from the generic mapping and specialise for known statuses.
	clientStatus := http.StatusBadGateway
	errType := "upstream_error"
	errMsg := "Upstream request failed"
	switch upstreamStatus {
	case http.StatusUnauthorized:
		errMsg = "Upstream authentication failed, please contact administrator"
	case http.StatusForbidden:
		errMsg = "Upstream access forbidden, please contact administrator"
	case http.StatusTooManyRequests:
		clientStatus = http.StatusTooManyRequests
		errType = "rate_limit_error"
		errMsg = "Upstream rate limit exceeded, please retry later"
	}

	c.JSON(clientStatus, gin.H{
		"error": gin.H{
			"type":    errType,
			"message": errMsg,
		},
	})
	return nil, fmt.Errorf("upstream error: %d", upstreamStatus)
}
// openaiStreamingResult is what handleStreamingResponse reports back to Forward.
type openaiStreamingResult struct {
// usage is the token usage collected from the stream's response.completed event.
usage *OpenAIUsage
// firstTokenMs is the time in milliseconds until the first non-empty data
// line, or nil if no such line was seen.
firstTokenMs *int
}
// handleStreamingResponse relays an upstream SSE stream to the client line by
// line, flushing after every line, while collecting token usage and the time
// to the first data line.
//
// When originalModel differs from mappedModel, "data: " lines are passed
// through replaceModelInSSELine so the client sees the model name it asked for.
// Usage is parsed from the line as forwarded (i.e. after any model rewrite).
//
// Returns an error if the client writer cannot flush, if writing to the client
// fails, or if reading the upstream stream fails; in the latter two cases the
// partially collected result is returned together with the error.
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// Pass through the upstream request ID, if any
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &OpenAIUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// Start with a 64 KiB buffer and allow lines of up to 1 MiB; a longer line
// stops the scan and surfaces through scanner.Err() below.
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
// Replace model in response if needed
if needModelReplace && strings.HasPrefix(line, "data: ") {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Forward line (the scanner strips the newline, so it is re-added here)
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Parse usage data
if strings.HasPrefix(line, "data: ") {
data := line[6:]
// Record first token time: the first data line that is neither empty nor
// the "[DONE]" terminator
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
}
}
if err := scanner.Err(); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine rewrites the model name carried by a single SSE
// "data: " line from fromModel to toModel.
//
// The JSON payload is inspected at two places, in this order: the top-level
// "model" field and the nested "response.model" field. The first one that
// equals fromModel is replaced and the event is re-encoded. The line is
// returned unchanged when it has no "data: " prefix, when the payload is
// empty, "[DONE]", not a single valid JSON object, or when neither field
// matches.
//
// Numbers are decoded as json.Number so that re-encoding reproduces them
// verbatim; decoding into float64 would silently alter integers above 2^53.
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
	const dataPrefix = "data: "
	// Guard the slice below: callers check the prefix, but a short line must
	// never panic here.
	if !strings.HasPrefix(line, dataPrefix) {
		return line
	}
	data := line[len(dataPrefix):]
	if data == "" || data == "[DONE]" {
		return line
	}

	dec := json.NewDecoder(strings.NewReader(data))
	dec.UseNumber()
	var event map[string]any
	if err := dec.Decode(&event); err != nil {
		return line
	}
	// Like json.Unmarshal, reject payloads with anything after the first value.
	if _, err := dec.Token(); !errors.Is(err, io.EOF) {
		return line
	}

	// Find the object whose "model" has to be rewritten: the event itself
	// first, then its nested "response" object.
	target := event
	if m, ok := event["model"].(string); !ok || m != fromModel {
		nested, ok := event["response"].(map[string]any)
		if !ok {
			return line
		}
		if m, ok := nested["model"].(string); !ok || m != fromModel {
			return line
		}
		target = nested
	}
	target["model"] = toModel

	newData, err := json.Marshal(event)
	if err != nil {
		return line
	}
	return dataPrefix + string(newData)
}
// parseSSEUsage extracts token usage from one SSE data payload.
//
// Only a payload that decodes without error and whose "type" is
// "response.completed" (OpenAI Responses format) is considered. In that case
// usage.InputTokens, usage.OutputTokens and usage.CacheReadInputTokens are
// overwritten from response.usage; any other payload leaves usage untouched.
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
	type usagePayload struct {
		InputTokens       int `json:"input_tokens"`
		OutputTokens      int `json:"output_tokens"`
		InputTokenDetails struct {
			CachedTokens int `json:"cached_tokens"`
		} `json:"input_tokens_details"`
	}
	var evt struct {
		Type     string `json:"type"`
		Response struct {
			Usage usagePayload `json:"usage"`
		} `json:"response"`
	}

	if err := json.Unmarshal([]byte(data), &evt); err != nil {
		return
	}
	if evt.Type != "response.completed" {
		return
	}

	reported := evt.Response.Usage
	usage.InputTokens = reported.InputTokens
	usage.OutputTokens = reported.OutputTokens
	usage.CacheReadInputTokens = reported.InputTokenDetails.CachedTokens
}
// handleNonStreamingResponse reads the complete upstream JSON response,
// extracts token usage from its "usage" object, optionally rewrites the model
// name back to the one the client requested, and writes the response to the
// client with the upstream status code.
//
// Upstream headers are passed through except hop-by-hop headers and
// Content-Length: the body may have been re-encoded by the model rewrite, so
// the upstream length no longer describes what is sent. Leaving it in place
// makes net/http reject or truncate the write; the server computes the correct
// length itself.
//
// Returns an error when the upstream body cannot be read or is not valid JSON;
// in that case nothing has been written to the client.
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// Only the usage section is needed here; everything else is forwarded as-is.
	var response struct {
		Usage struct {
			InputTokens       int `json:"input_tokens"`
			OutputTokens      int `json:"output_tokens"`
			InputTokenDetails struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"input_tokens_details"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}
	usage := &OpenAIUsage{
		InputTokens:          response.Usage.InputTokens,
		OutputTokens:         response.Usage.OutputTokens,
		CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens,
	}

	// Show the client the model name it asked for, not the mapped one.
	if originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}

	for key, values := range resp.Header {
		if openaiSkipResponseHeader(key) {
			continue
		}
		for _, value := range values {
			c.Header(key, value)
		}
	}
	c.Data(resp.StatusCode, "application/json", body)
	return usage, nil
}

// openaiSkipResponseHeader reports whether an upstream response header must
// not be copied to the client: hop-by-hop headers, which only apply to the
// upstream connection, and Content-Length, which is recomputed for the body
// actually written.
func openaiSkipResponseHeader(key string) bool {
	switch http.CanonicalHeaderKey(key) {
	case "Content-Length",
		"Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Te",
		"Trailer",
		"Transfer-Encoding",
		"Upgrade":
		return true
	}
	return false
}
// replaceModelInResponseBody rewrites the top-level "model" field of a JSON
// response body from fromModel to toModel and returns the re-encoded body.
//
// The original body is returned unchanged when it is not a single valid JSON
// object, when "model" is missing or not a string, or when it differs from
// fromModel.
//
// Numbers are decoded as json.Number so that re-encoding reproduces them
// verbatim; decoding into float64 would silently alter integers above 2^53.
func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
	dec := json.NewDecoder(bytes.NewReader(body))
	dec.UseNumber()
	var payload map[string]any
	if err := dec.Decode(&payload); err != nil {
		return body
	}
	// Like json.Unmarshal, reject bodies with anything after the first value.
	if _, err := dec.Token(); !errors.Is(err, io.EOF) {
		return body
	}

	current, ok := payload["model"].(string)
	if !ok || current != fromModel {
		return body
	}
	payload["model"] = toModel

	rewritten, err := json.Marshal(payload)
	if err != nil {
		return body
	}
	return rewritten
}
// OpenAIRecordUsageInput bundles everything RecordUsage needs for one request.
type OpenAIRecordUsageInput struct {
// Result is the outcome of the forwarded request (model, usage, timings).
Result *OpenAIForwardResult
// ApiKey is the API key the request was made with; its group determines the
// rate multiplier and billing type.
ApiKey *model.ApiKey
// User is the user who is billed.
User *model.User
// Account is the upstream account that served the request.
Account *model.Account
// Subscription may be nil; subscription billing is only used when it is set.
Subscription *model.UserSubscription
}
// RecordUsage records the usage of one forwarded request and charges for it.
//
// Steps, in order:
//  1. Compute the cost of the request from its token counts and the rate
//     multiplier (the group's multiplier when the API key has a group, else the
//     configured default).
//  2. Write a usage log entry.
//  3. Charge: subscription billing increments the subscription usage by
//     TotalCost; balance billing deducts ActualCost from the user's balance.
//     The billing cache is updated in a background goroutine with its own
//     5-second timeout.
//  4. Update the account's last-used time.
//
// Errors from the repositories and the billing cache are discarded; the
// function always returns nil.
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
// Calculate cost
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
// Get rate multiplier
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
}
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
if err != nil {
// Fall back to a zero cost: the usage log is still written, but nothing is
// charged because both cost checks below require a value > 0.
cost = &CostBreakdown{ActualCost: 0}
}
// Determine billing type
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance
if isSubscriptionBilling {
billingType = model.BillingTypeSubscription
}
// Create usage log
durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
_ = s.usageLogRepo.Create(ctx, usageLog)
// Deduct based on billing type
if isSubscriptionBilling {
if cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
// The cache update uses a fresh background context so it is not tied to
// the lifetime of ctx.
// NOTE(review): isSubscriptionBilling only guarantees apiKey.Group != nil;
// *apiKey.GroupID below would panic inside the goroutine if GroupID were
// nil while Group is set - confirm callers always populate both.
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
}()
}
} else {
if cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
// Same detached-context pattern as the subscription branch.
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
}()
}
}
// Update account last used
_ = s.accountRepo.UpdateLastUsed(ctx, account.ID)
return nil
}