merge: 合并 upstream/main 解决 PR #37 冲突

- 删除 backend/internal/model/account.go 符合重构方向
- 合并最新的项目结构重构
- 包含 SSE 格式解析修复
- 更新依赖和配置文件
This commit is contained in:
IanShaw027
2025-12-26 21:56:08 +08:00
118 changed files with 6077 additions and 3478 deletions

View File

@@ -11,12 +11,12 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
)
@@ -28,6 +28,10 @@ const (
openaiStickySessionTTL = time.Hour // 粘性会话TTL
)
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAI allowed headers whitelist (for non-OAuth accounts)
var openaiAllowedHeaders = map[string]bool{
"accept-language": true,
@@ -119,12 +123,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
}
// SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel selects an account supporting the requested model
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
@@ -139,19 +143,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// 2. Get schedulable OpenAI accounts
var accounts []model.Account
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. Select by priority + LRU
var selected *model.Account
var selected *Account
for i := range accounts {
acc := &accounts[i]
// Check model support
@@ -198,15 +202,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
}
// GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case model.AccountTypeOAuth:
case AccountTypeOAuth:
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case model.AccountTypeApiKey:
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -218,7 +222,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode
}
// Forward forwards request to OpenAI API
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
startTime := time.Now()
// Parse request body once (avoid multiple parse/serialize cycles)
@@ -243,7 +247,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// For OAuth accounts using ChatGPT internal API, add store: false
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
}
@@ -305,7 +309,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
}
@@ -321,14 +325,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
case model.AccountTypeOAuth:
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case model.AccountTypeApiKey:
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -349,7 +353,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
req.Header.Set("authorization", "Bearer "+token)
// Set headers specific to OAuth accounts (ChatGPT internal API)
if account.Type == model.AccountTypeOAuth {
if account.Type == AccountTypeOAuth {
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
req.Host = "chatgpt.com"
// Required: set chatgpt-account-id header
@@ -389,7 +393,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
return req, nil
}
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// Check custom error codes
@@ -445,7 +449,7 @@ type openaiStreamingResult struct {
firstTokenMs *int
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
// Set SSE response headers
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -473,26 +477,33 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
for scanner.Scan() {
line := scanner.Text()
// Replace model in response if needed
if needModelReplace && strings.HasPrefix(line, "data: ") {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Replace model in response if needed
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// Forward line
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// Parse usage data
if strings.HasPrefix(line, "data: ") {
data := line[6:]
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
@@ -504,7 +515,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
data := line[6:]
if !openaiSSEDataRe.MatchString(line) {
return line
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
@@ -561,7 +575,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
}
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
@@ -627,10 +641,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *model.ApiKey
User *model.User
Account *model.Account
Subscription *model.UserSubscription
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription
}
// RecordUsage records usage and deducts balance
@@ -669,14 +683,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Determine billing type
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := model.BillingTypeBalance
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = model.BillingTypeSubscription
billingType = BillingTypeSubscription
}
// Create usage log
durationMs := int(result.Duration.Milliseconds())
usageLog := &model.UsageLog{
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,