Merge branch 'test' into release
This commit is contained in:
@@ -341,7 +341,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
|
||||
pageSize := dataPageCap
|
||||
var out []service.Account
|
||||
for {
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
|
||||
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -200,7 +200,12 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
|
||||
var groupID int64
|
||||
if groupIDStr := c.Query("group"); groupIDStr != "" {
|
||||
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
|
||||
}
|
||||
|
||||
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -1433,6 +1438,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Sora accounts
|
||||
if account.Platform == service.PlatformSora {
|
||||
response.Success(c, service.DefaultSoraModels(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Handle Claude/Anthropic accounts
|
||||
// For OAuth and Setup-Token accounts: return default models
|
||||
if account.IsOAuth() {
|
||||
@@ -1542,7 +1553,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
|
||||
accounts := make([]*service.Account, 0)
|
||||
|
||||
if len(req.AccountIDs) == 0 {
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
|
||||
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
@@ -166,7 +166,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
@@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
|
||||
return &service.ProxyQualityCheckResult{
|
||||
ProxyID: id,
|
||||
Score: 95,
|
||||
Grade: "A",
|
||||
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
|
||||
PassedCount: 5,
|
||||
WarnCount: 0,
|
||||
FailedCount: 0,
|
||||
ChallengeCount: 0,
|
||||
CheckedAt: time.Now().Unix(),
|
||||
Items: []service.ProxyQualityCheckItem{
|
||||
{Target: "base_connectivity", Status: "pass", Message: "ok"},
|
||||
{Target: "openai", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
|
||||
{Target: "gemini", Status: "pass", HTTPStatus: 200},
|
||||
{Target: "sora", Status: "pass", HTTPStatus: 401},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
@@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct {
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
func oauthPlatformFromPath(c *gin.Context) string {
|
||||
if strings.Contains(c.FullPath(), "/admin/sora/") {
|
||||
return service.PlatformSora
|
||||
}
|
||||
return service.PlatformOpenAI
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
|
||||
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
|
||||
return &OpenAIOAuthHandler{
|
||||
@@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
|
||||
type OpenAIExchangeCodeRequest struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
@@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
|
||||
|
||||
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
|
||||
type OpenAIRefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
RT string `json:"rt"`
|
||||
ClientID string `json:"client_id"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
// POST /api/v1/admin/openai/refresh-token
|
||||
// POST /api/v1/admin/sora/rt2at
|
||||
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req OpenAIRefreshTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
refreshToken = strings.TrimSpace(req.RT)
|
||||
}
|
||||
if refreshToken == "" {
|
||||
response.BadRequest(c, "refresh_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if req.ProxyID != nil {
|
||||
@@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
|
||||
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID))
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -111,8 +132,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI account
|
||||
// ExchangeSoraSessionToken exchanges Sora session token to access token
|
||||
// POST /api/v1/admin/sora/st2at
|
||||
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionToken string `json:"session_token"`
|
||||
ST string `json:"st"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(req.SessionToken)
|
||||
if sessionToken == "" {
|
||||
sessionToken = strings.TrimSpace(req.ST)
|
||||
}
|
||||
if sessionToken == "" {
|
||||
response.BadRequest(c, "session_token is required")
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, tokenInfo)
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
|
||||
// POST /api/v1/admin/openai/accounts/:id/refresh
|
||||
// POST /api/v1/admin/sora/accounts/:id/refresh
|
||||
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure account is OpenAI platform
|
||||
if !account.IsOpenAI() {
|
||||
response.BadRequest(c, "Account is not an OpenAI account")
|
||||
platform := oauthPlatformFromPath(c)
|
||||
if account.Platform != platform {
|
||||
response.BadRequest(c, "Account platform does not match OAuth endpoint")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
|
||||
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
|
||||
// POST /api/v1/admin/openai/create-from-oauth
|
||||
// POST /api/v1/admin/sora/create-from-oauth
|
||||
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
State string `json:"state" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Name string `json:"name"`
|
||||
@@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
|
||||
SessionID: req.SessionID,
|
||||
Code: req.Code,
|
||||
State: req.State,
|
||||
RedirectURI: req.RedirectURI,
|
||||
ProxyID: req.ProxyID,
|
||||
})
|
||||
@@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
|
||||
// Build credentials from token info
|
||||
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
|
||||
platform := oauthPlatformFromPath(c)
|
||||
|
||||
// Use email as default name if not provided
|
||||
name := req.Name
|
||||
if name == "" && tokenInfo.Email != "" {
|
||||
name = tokenInfo.Email
|
||||
}
|
||||
if name == "" {
|
||||
name = "OpenAI OAuth Account"
|
||||
if platform == service.PlatformSora {
|
||||
name = "Sora OAuth Account"
|
||||
} else {
|
||||
name = "OpenAI OAuth Account"
|
||||
}
|
||||
}
|
||||
|
||||
// Create account
|
||||
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
|
||||
Name: name,
|
||||
Platform: "openai",
|
||||
Platform: platform,
|
||||
Type: "oauth",
|
||||
Credentials: credentials,
|
||||
ProxyID: req.ProxyID,
|
||||
|
||||
@@ -236,6 +236,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// CheckQuality handles checking proxy quality across common AI targets.
|
||||
// POST /api/v1/admin/proxies/:id/quality-check
|
||||
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
|
||||
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid proxy ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// GetStats handles getting proxy statistics
|
||||
// GET /api/v1/admin/proxies/:id/stats
|
||||
func (h *ProxyHandler) GetStats(c *gin.Context) {
|
||||
|
||||
@@ -214,6 +214,13 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
// 缓存 TTL 强制替换
|
||||
if a.IsCacheTTLOverrideEnabled() {
|
||||
enabled := true
|
||||
out.CacheTTLOverrideEnabled = &enabled
|
||||
target := a.GetCacheTTLOverrideTarget()
|
||||
out.CacheTTLOverrideTarget = &target
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -296,6 +303,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
|
||||
CountryCode: p.CountryCode,
|
||||
Region: p.Region,
|
||||
City: p.City,
|
||||
QualityStatus: p.QualityStatus,
|
||||
QualityScore: p.QualityScore,
|
||||
QualityGrade: p.QualityGrade,
|
||||
QualitySummary: p.QualitySummary,
|
||||
QualityChecked: p.QualityChecked,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
ImageSize: l.ImageSize,
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
|
||||
@@ -156,6 +156,11 @@ type Account struct {
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
// 缓存 TTL 强制替换(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将所有 cache creation tokens 归入指定的 TTL 类型计费
|
||||
CacheTTLOverrideEnabled *bool `json:"cache_ttl_override_enabled,omitempty"`
|
||||
CacheTTLOverrideTarget *string `json:"cache_ttl_override_target,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -197,6 +202,11 @@ type ProxyWithAccountCount struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
City string `json:"city,omitempty"`
|
||||
QualityStatus string `json:"quality_status,omitempty"`
|
||||
QualityScore *int `json:"quality_score,omitempty"`
|
||||
QualityGrade string `json:"quality_grade,omitempty"`
|
||||
QualitySummary string `json:"quality_summary,omitempty"`
|
||||
QualityChecked *int64 `json:"quality_checked,omitempty"`
|
||||
}
|
||||
|
||||
type ProxyAccountSummary struct {
|
||||
@@ -280,6 +290,9 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// Cache TTL Override 标记
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/soraerror"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -35,6 +37,7 @@ type SoraGatewayHandler struct {
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
streamMode string
|
||||
soraTLSEnabled bool
|
||||
soraMediaSigningKey string
|
||||
soraMediaRoot string
|
||||
}
|
||||
@@ -50,6 +53,7 @@ func NewSoraGatewayHandler(
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
streamMode := "force"
|
||||
soraTLSEnabled := true
|
||||
signKey := ""
|
||||
mediaRoot := "/app/data/sora"
|
||||
if cfg != nil {
|
||||
@@ -60,6 +64,7 @@ func NewSoraGatewayHandler(
|
||||
if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" {
|
||||
streamMode = mode
|
||||
}
|
||||
soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint
|
||||
signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey)
|
||||
if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" {
|
||||
mediaRoot = root
|
||||
@@ -72,6 +77,7 @@ func NewSoraGatewayHandler(
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
streamMode: strings.ToLower(streamMode),
|
||||
soraTLSEnabled: soraTLSEnabled,
|
||||
soraMediaSigningKey: signKey,
|
||||
soraMediaRoot: mediaRoot,
|
||||
}
|
||||
@@ -212,6 +218,8 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverBody []byte
|
||||
var lastFailoverHeaders http.Header
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "")
|
||||
@@ -224,11 +232,31 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int("last_upstream_status", lastFailoverStatus),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("last_upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
proxyBound := account.ProxyID != nil
|
||||
proxyID := int64(0)
|
||||
if account.ProxyID != nil {
|
||||
proxyID = *account.ProxyID
|
||||
}
|
||||
tlsFingerprintEnabled := h.soraTLSEnabled
|
||||
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -239,10 +267,19 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
reqLog.Warn("sora.account_wait_counter_increment_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else if !canWait {
|
||||
reqLog.Info("sora.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
@@ -266,7 +303,13 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("sora.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
reqLog.Warn("sora.account_slot_acquire_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
@@ -287,20 +330,67 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_exhausted", fields...)
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders)
|
||||
lastFailoverBody = failoverErr.ResponseBody
|
||||
switchCount++
|
||||
reqLog.Warn("sora.upstream_failover_switching",
|
||||
upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody)
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody)
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.String("upstream_error_code", upstreamErrCode),
|
||||
zap.String("upstream_error_message", upstreamErrMsg),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
}
|
||||
if rayID != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_ray", rayID))
|
||||
}
|
||||
if mitigated != "" {
|
||||
fields = append(fields, zap.String("upstream_cf_mitigated", mitigated))
|
||||
}
|
||||
if contentType != "" {
|
||||
fields = append(fields, zap.String("upstream_content_type", contentType))
|
||||
}
|
||||
reqLog.Warn("sora.upstream_failover_switching", fields...)
|
||||
continue
|
||||
}
|
||||
reqLog.Error("sora.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
reqLog.Error("sora.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -331,6 +421,9 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}(result, account, userAgent, clientIP)
|
||||
reqLog.Debug("sora.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int64("proxy_id", proxyID),
|
||||
zap.Bool("proxy_bound", proxyBound),
|
||||
zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
@@ -360,17 +453,41 @@ func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, s
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||||
func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) {
|
||||
if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) {
|
||||
baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode)
|
||||
return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
|
||||
upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody)
|
||||
if strings.EqualFold(upstreamCode, "cf_shield_429") {
|
||||
baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry."
|
||||
return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody)
|
||||
}
|
||||
if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) {
|
||||
switch statusCode {
|
||||
case 401, 403, 404, 500, 502, 503, 504:
|
||||
return http.StatusBadGateway, "upstream_error", upstreamMessage
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage
|
||||
}
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||||
case 403:
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||||
case 404:
|
||||
if strings.EqualFold(upstreamCode, "unsupported_country_code") {
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator"
|
||||
}
|
||||
return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator"
|
||||
case 429:
|
||||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||||
case 529:
|
||||
@@ -382,11 +499,67 @@ func (h *SoraGatewayHandler) mapUpstreamError(statusCode int) (int, string, stri
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHTTPHeaders(headers http.Header) http.Header {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
return headers.Clone()
|
||||
}
|
||||
|
||||
func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) {
|
||||
if headers != nil {
|
||||
mitigated = strings.TrimSpace(headers.Get("cf-mitigated"))
|
||||
contentType = strings.TrimSpace(headers.Get("content-type"))
|
||||
if contentType == "" {
|
||||
contentType = strings.TrimSpace(headers.Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
rayID = soraerror.ExtractCloudflareRayID(headers, body)
|
||||
return rayID, mitigated, contentType
|
||||
}
|
||||
|
||||
func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool {
|
||||
return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body)
|
||||
}
|
||||
|
||||
func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool {
|
||||
message = strings.TrimSpace(message)
|
||||
if message == "" {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests {
|
||||
lower := strings.ToLower(message)
|
||||
if strings.Contains(lower, "<html") || strings.Contains(lower, "<!doctype html") || strings.Contains(lower, "window._cf_chl_opt") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatSoraCloudflareChallengeMessage(base string, headers http.Header, body []byte) string {
|
||||
return soraerror.FormatCloudflareChallengeMessage(base, headers, body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorCodeAndMessage(body []byte) (string, string) {
|
||||
return soraerror.ExtractUpstreamErrorCodeAndMessage(body)
|
||||
}
|
||||
|
||||
func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
|
||||
@@ -43,6 +43,48 @@ func (s *stubSoraClient) CreateImageTask(ctx context.Context, account *service.A
|
||||
func (s *stubSoraClient) CreateVideoTask(ctx context.Context, account *service.Account, req service.SoraVideoRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) CreateStoryboardTask(ctx context.Context, account *service.Account, req service.SoraStoryboardRequest) (string, error) {
|
||||
return "task-video", nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterVideo(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "cameo-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetCameoStatus(ctx context.Context, account *service.Account, cameoID string) (*service.SoraCameoStatus, error) {
|
||||
return &service.SoraCameoStatus{
|
||||
Status: "finalized",
|
||||
StatusMessage: "Completed",
|
||||
DisplayNameHint: "Character",
|
||||
UsernameHint: "user.character",
|
||||
ProfileAssetURL: "https://example.com/avatar.webp",
|
||||
}, nil
|
||||
}
|
||||
func (s *stubSoraClient) DownloadCharacterImage(ctx context.Context, account *service.Account, imageURL string) ([]byte, error) {
|
||||
return []byte("avatar"), nil
|
||||
}
|
||||
func (s *stubSoraClient) UploadCharacterImage(ctx context.Context, account *service.Account, data []byte) (string, error) {
|
||||
return "asset-pointer", nil
|
||||
}
|
||||
func (s *stubSoraClient) FinalizeCharacter(ctx context.Context, account *service.Account, req service.SoraCharacterFinalizeRequest) (string, error) {
|
||||
return "character-1", nil
|
||||
}
|
||||
func (s *stubSoraClient) SetCharacterPublic(ctx context.Context, account *service.Account, cameoID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) DeleteCharacter(ctx context.Context, account *service.Account, characterID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) PostVideoForWatermarkFree(ctx context.Context, account *service.Account, generationID string) (string, error) {
|
||||
return "s_post", nil
|
||||
}
|
||||
func (s *stubSoraClient) DeletePost(ctx context.Context, account *service.Account, postID string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *stubSoraClient) GetWatermarkFreeURLCustom(ctx context.Context, account *service.Account, parseURL, parseToken, postID string) (string, error) {
|
||||
return "https://example.com/no-watermark.mp4", nil
|
||||
}
|
||||
func (s *stubSoraClient) EnhancePrompt(ctx context.Context, account *service.Account, prompt, expansionLevel string, durationS int) (string, error) {
|
||||
return "enhanced prompt", nil
|
||||
}
|
||||
func (s *stubSoraClient) GetImageTask(ctx context.Context, account *service.Account, taskID string) (*service.SoraImageTaskStatus, error) {
|
||||
return &service.SoraImageTaskStatus{ID: taskID, Status: "completed", URLs: s.imageURLs}, nil
|
||||
}
|
||||
@@ -88,7 +130,7 @@ func (r *stubAccountRepo) Delete(ctx context.Context, id int64) error
|
||||
func (r *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
func (r *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]service.Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (r *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
@@ -495,3 +537,152 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
||||
require.NotEmpty(t, hash3)
|
||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||
}
|
||||
|
||||
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "包含双引号",
|
||||
errType: "upstream_error",
|
||||
message: `upstream returned "invalid" payload`,
|
||||
},
|
||||
{
|
||||
name: "包含换行和制表符",
|
||||
errType: "rate_limit_error",
|
||||
message: "line1\nline2\ttab",
|
||||
},
|
||||
{
|
||||
name: "包含反斜杠",
|
||||
errType: "upstream_error",
|
||||
message: `path C:\Users\test\file.txt not found`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||
require.Equal(t, "event: error", lines[0])
|
||||
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||
require.Equal(t, tt.errType, errorObj["type"])
|
||||
require.Equal(t, tt.message, errorObj["message"])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||
h.handleFailoverExhausted(c, http.StatusBadGateway, nil, resp, true)
|
||||
|
||||
body := w.Body.String()
|
||||
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CloudflareChallengeIncludesRay(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d01b0e9ecc35829-SEA")
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script></body></html>`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "upstream_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare challenge")
|
||||
require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA")
|
||||
}
|
||||
|
||||
func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-ray", "9d03b68c086027a1-SEA")
|
||||
body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`)
|
||||
|
||||
h := &SoraGatewayHandler{}
|
||||
h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true)
|
||||
|
||||
lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n")
|
||||
require.Len(t, lines, 2)
|
||||
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||
|
||||
errorObj, ok := parsed["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "rate_limit_error", errorObj["type"])
|
||||
msg, _ := errorObj["message"].(string)
|
||||
require.Contains(t, msg, "Cloudflare shield")
|
||||
require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA")
|
||||
}
|
||||
|
||||
func TestExtractSoraFailoverHeaderInsights(t *testing.T) {
|
||||
headers := http.Header{}
|
||||
headers.Set("cf-mitigated", "challenge")
|
||||
headers.Set("content-type", "text/html")
|
||||
body := []byte(`<script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script>`)
|
||||
|
||||
rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body)
|
||||
require.Equal(t, "9cff2d62d83bb98d", rayID)
|
||||
require.Equal(t, "challenge", mitigated)
|
||||
require.Equal(t, "text/html", contentType)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user