merge: 合并 upstream/main 并解决冲突

解决了以下文件的冲突:
- backend/internal/handler/admin/setting_handler.go
  - 采用 upstream 的字段对齐风格和 *Configured 字段名
  - 添加 EnableIdentityPatch 和 IdentityPatchPrompt 字段

- backend/internal/handler/gateway_handler.go
  - 采用 upstream 的 billingErrorDetails 错误处理方式

- frontend/src/api/admin/settings.ts
  - 采用 upstream 的 *_configured 字段名
  - 添加 enable_identity_patch 和 identity_patch_prompt 字段

- frontend/src/views/admin/SettingsView.vue
  - 合并 turnstile_secret_key_configured 字段
  - 保留 enable_identity_patch 和 identity_patch_prompt 字段
This commit is contained in:
IanShaw027
2026-01-04 23:17:15 +08:00
65 changed files with 2712 additions and 796 deletions

View File

@@ -1,8 +1,12 @@
package admin
import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -34,33 +38,33 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
EnableIdentityPatch: settings.EnableIdentityPatch,
IdentityPatchPrompt: settings.IdentityPatchPrompt,
})
}
@@ -117,6 +121,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
previousSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
@@ -193,6 +203,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
return
}
h.auditSettingsUpdate(c, previousSettings, settings, req)
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
@@ -201,36 +213,136 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
})
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) {
if before == nil || after == nil {
return
}
changed := diffSettings(before, after, req)
if len(changed) == 0 {
return
}
subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
time.Now().UTC().Format(time.RFC3339),
subject.UserID,
role,
changed,
)
}
func diffSettings(before *service.SystemSettings, after *service.SystemSettings, req UpdateSettingsRequest) []string {
changed := make([]string, 0, 20)
if before.RegistrationEnabled != after.RegistrationEnabled {
changed = append(changed, "registration_enabled")
}
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
if before.SMTPPort != after.SMTPPort {
changed = append(changed, "smtp_port")
}
if before.SMTPUsername != after.SMTPUsername {
changed = append(changed, "smtp_username")
}
if req.SMTPPassword != "" {
changed = append(changed, "smtp_password")
}
if before.SMTPFrom != after.SMTPFrom {
changed = append(changed, "smtp_from_email")
}
if before.SMTPFromName != after.SMTPFromName {
changed = append(changed, "smtp_from_name")
}
if before.SMTPUseTLS != after.SMTPUseTLS {
changed = append(changed, "smtp_use_tls")
}
if before.TurnstileEnabled != after.TurnstileEnabled {
changed = append(changed, "turnstile_enabled")
}
if before.TurnstileSiteKey != after.TurnstileSiteKey {
changed = append(changed, "turnstile_site_key")
}
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.SiteName != after.SiteName {
changed = append(changed, "site_name")
}
if before.SiteLogo != after.SiteLogo {
changed = append(changed, "site_logo")
}
if before.SiteSubtitle != after.SiteSubtitle {
changed = append(changed, "site_subtitle")
}
if before.APIBaseURL != after.APIBaseURL {
changed = append(changed, "api_base_url")
}
if before.ContactInfo != after.ContactInfo {
changed = append(changed, "contact_info")
}
if before.DocURL != after.DocURL {
changed = append(changed, "doc_url")
}
if before.DefaultConcurrency != after.DefaultConcurrency {
changed = append(changed, "default_concurrency")
}
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
if before.FallbackModelAnthropic != after.FallbackModelAnthropic {
changed = append(changed, "fallback_model_anthropic")
}
if before.FallbackModelOpenAI != after.FallbackModelOpenAI {
changed = append(changed, "fallback_model_openai")
}
if before.FallbackModelGemini != after.FallbackModelGemini {
changed = append(changed, "fallback_model_gemini")
}
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
changed = append(changed, "fallback_model_antigravity")
}
return changed
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`

View File

@@ -5,17 +5,17 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPasswordConfigured bool `json:"smtp_password_configured"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key,omitempty"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`

View File

@@ -1,7 +1,6 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -12,8 +11,10 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -21,10 +22,6 @@ import (
"github.com/gin-gonic/gin"
)
const maxGatewayRequestBodyBytes int64 = 10 * 1024 * 1024 // 10MB
var errEmptyRequestBody = errors.New("request body is empty")
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
@@ -35,23 +32,6 @@ type GatewayHandler struct {
concurrencyHelper *ConcurrencyHelper
}
func (h *GatewayHandler) recordUsageSync(apiKey *service.APIKey, subscription *service.UserSubscription, result *service.ForwardResult, usedAccount *service.Account) {
// 计费属于关键数据:同步写入,避免 goroutine 异步导致进程崩溃时丢失使用量/扣费数据。
// 使用独立 Background context避免客户端取消请求导致计费中断。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: request_id=%s user=%d api_key=%d account=%d err=%v", result.RequestID, apiKey.UserID, apiKey.ID, usedAccount.ID, err)
}
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(
gatewayService *service.GatewayService,
@@ -60,89 +40,22 @@ func NewGatewayHandler(
userService *service.UserService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &GatewayHandler{
gatewayService: gatewayService,
geminiCompatService: geminiCompatService,
antigravityGatewayService: antigravityGatewayService,
userService: userService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
}
}
func parseGatewayRequestStream(r io.Reader, limit int64) (*service.ParsedRequest, error) {
if r == nil {
return nil, errEmptyRequestBody
}
var raw bytes.Buffer
limited := io.LimitReader(r, limit+1)
tee := io.TeeReader(limited, &raw)
decoder := json.NewDecoder(tee)
var req map[string]any
if err := decoder.Decode(&req); err != nil {
if errors.Is(err, io.EOF) {
return nil, errEmptyRequestBody
}
if int64(raw.Len()) > limit {
return nil, &http.MaxBytesError{Limit: limit}
}
return nil, err
}
// Ensure the body contains exactly one JSON value (allowing trailing whitespace).
var extra any
if err := decoder.Decode(&extra); err != io.EOF {
if int64(raw.Len()) > limit {
return nil, &http.MaxBytesError{Limit: limit}
}
if err == nil {
return nil, fmt.Errorf("request body must contain a single JSON object")
}
return nil, err
}
if int64(raw.Len()) > limit {
return nil, &http.MaxBytesError{Limit: limit}
}
parsed := &service.ParsedRequest{
Body: raw.Bytes(),
}
if rawModel, exists := req["model"]; exists {
model, ok := rawModel.(string)
if !ok {
return nil, fmt.Errorf("invalid model field type")
}
parsed.Model = model
}
if rawStream, exists := req["stream"]; exists {
stream, ok := rawStream.(bool)
if !ok {
return nil, fmt.Errorf("invalid stream field type")
}
parsed.Stream = stream
}
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
parsed.MetadataUserID = userID
}
}
// system 字段只要存在就视为显式提供(即使为 null
// 以避免客户端传 null 时被默认 system 误注入。
if system, ok := req["system"]; ok {
parsed.HasSystem = true
parsed.System = system
}
if messages, ok := req["messages"].([]any); ok {
parsed.Messages = messages
}
return parsed, nil
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
@@ -159,29 +72,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
if errors.Is(err, errEmptyRequestBody) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(parsedReq.Body) == 0 {
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
@@ -217,6 +128,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 在请求结束或 Context 取消时确保释放槽位,避免客户端断开造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -224,7 +137,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "permission_error", "Insufficient balance or active subscription required", streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -252,9 +166,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
log.Printf("Select account failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model", streamStarted)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
@@ -263,7 +176,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
@@ -317,13 +230,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, parsedReq.Body)
result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body)
} else {
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, parsedReq.Body)
result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -350,8 +266,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 步记录使用量,避免进程崩溃导致计费数据丢失subscription已在函数开头获取
h.recordUsageSync(apiKey, subscription, result, account)
// 步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}
@@ -365,9 +293,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
if err != nil {
log.Printf("Select account failed: %v", err)
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model", streamStarted)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
@@ -376,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
// 检查预热请求拦截(在账号选择后、转发前检查)
if account.IsInterceptWarmupEnabled() && isWarmupRequest(parsedReq.Body) {
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
}
@@ -430,11 +357,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 转发请求 - 根据账号平台分流
var result *service.ForwardResult
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, parsedReq.Body)
result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body)
} else {
result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq)
}
@@ -463,8 +393,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 步记录使用量,避免进程崩溃导致计费数据丢失subscription已在函数开头获取
h.recordUsageSync(apiKey, subscription, result, account)
// 步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account)
return
}
}
@@ -640,71 +582,32 @@ func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int,
func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
switch statusCode {
case 401:
return http.StatusBadGateway, "api_error", "Upstream authentication failed, please contact administrator"
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
case 403:
return http.StatusBadGateway, "api_error", "Upstream access forbidden, please contact administrator"
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
case 429:
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
case 529:
return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
return http.StatusBadGateway, "api_error", "Upstream service temporarily unavailable"
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
default:
return http.StatusBadGateway, "api_error", "Upstream request failed"
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
}
}
func normalizeAnthropicErrorType(errType string) string {
switch errType {
case "invalid_request_error",
"authentication_error",
"permission_error",
"not_found_error",
"rate_limit_error",
"api_error",
"overloaded_error":
return errType
case "billing_error":
// Not an Anthropic-standard error type; map to the closest equivalent.
return "permission_error"
case "subscription_error":
// Not an Anthropic-standard error type; map to the closest equivalent.
return "permission_error"
case "upstream_error":
// Not an Anthropic-standard error type; keep clients compatible.
return "api_error"
default:
return "api_error"
}
}
const maxPublicErrorMessageLen = 512
func sanitizePublicErrorMessage(message string) string {
cleaned := strings.TrimSpace(message)
cleaned = strings.ReplaceAll(cleaned, "\r", " ")
cleaned = strings.ReplaceAll(cleaned, "\n", " ")
if len(cleaned) > maxPublicErrorMessageLen {
cleaned = cleaned[:maxPublicErrorMessageLen] + "..."
}
return cleaned
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
normalizedType := normalizeAnthropicErrorType(errType)
publicMessage := sanitizePublicErrorMessage(message)
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Anthropic streaming spec: send `event: error` with JSON `data`.
// Send error event in SSE format with proper JSON marshaling
errorData := map[string]any{
"type": "error",
"error": map[string]string{
"type": normalizedType,
"message": publicMessage,
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
@@ -712,11 +615,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
_ = c.Error(err)
return
}
if _, err := fmt.Fprintf(c.Writer, "event: error\n"); err != nil {
_ = c.Error(err)
return
}
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", string(jsonBytes)); err != nil {
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
flusher.Flush()
@@ -725,19 +625,16 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, normalizedType, publicMessage)
h.errorResponse(c, status, errType, message)
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
normalizedType := normalizeAnthropicErrorType(errType)
publicMessage := sanitizePublicErrorMessage(message)
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": normalizedType,
"message": publicMessage,
"type": errType,
"message": message,
},
})
}
@@ -759,30 +656,28 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
parsedReq, err := parseGatewayRequestStream(c.Request.Body, maxGatewayRequestBodyBytes)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
return
}
if errors.Is(err, errEmptyRequestBody) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
var syntaxErr *json.SyntaxError
var typeErr *json.UnmarshalTypeError
if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) || errors.Is(err, io.ErrUnexpectedEOF) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(parsedReq.Body) == 0 {
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
parsedReq, err := service.ParseGatewayRequest(body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 验证 model 必填
if parsedReq.Model == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
@@ -795,8 +690,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 校验 billing eligibility订阅/余额)
// 【注意】不计算并发,但需要校验订阅/余额
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed: %v", err)
h.errorResponse(c, http.StatusForbidden, "permission_error", "Insufficient balance or active subscription required")
status, code, message := billingErrorDetails(err)
h.errorResponse(c, status, code, message)
return
}
@@ -806,8 +701,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
log.Printf("Select account failed: %v", err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts for requested model")
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@@ -923,3 +817,18 @@ func sendMockWarmupResponse(c *gin.Context, model string) {
},
})
}
func billingErrorDetails(err error) (status int, code, message string) {
if errors.Is(err, service.ErrBillingServiceUnavailable) {
msg := pkgerrors.Message(err)
if msg == "" {
msg = "Billing service temporarily unavailable. Please retry later."
}
return http.StatusServiceUnavailable, "billing_service_error", msg
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
}
return http.StatusForbidden, "billing_error", msg
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,8 +27,8 @@ import (
const (
// maxConcurrencyWait 等待并发槽位的最大时间
maxConcurrencyWait = 30 * time.Second
// pingInterval 流式响应等待时发送 ping 的间隔
pingInterval = 15 * time.Second
// defaultPingInterval 流式响应等待时发送 ping 的默认间隔
defaultPingInterval = 10 * time.Second
// initialBackoff 初始退避时间
initialBackoff = 100 * time.Millisecond
// backoffMultiplier 退避时间乘数(指数退避)
@@ -44,6 +45,8 @@ const (
SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"
// SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec)
SSEPingFormatNone SSEPingFormat = ""
// SSEPingFormatComment is an SSE comment ping for OpenAI/Codex CLI clients
SSEPingFormatComment SSEPingFormat = ":\n\n"
)
// ConcurrencyError represents a concurrency limit error with context
@@ -63,16 +66,38 @@ func (e *ConcurrencyError) Error() string {
type ConcurrencyHelper struct {
concurrencyService *service.ConcurrencyService
pingFormat SSEPingFormat
pingInterval time.Duration
}
// NewConcurrencyHelper creates a new ConcurrencyHelper
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper {
if pingInterval <= 0 {
pingInterval = defaultPingInterval
}
return &ConcurrencyHelper{
concurrencyService: concurrencyService,
pingFormat: pingFormat,
pingInterval: pingInterval,
}
}
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
wrapped := func() {
once.Do(releaseFunc)
}
go func() {
<-ctx.Done()
wrapped()
}()
return wrapped
}
// IncrementWaitCount increments the wait count for a user
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
@@ -174,7 +199,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
// Only create ping ticker if ping is needed
var pingCh <-chan time.Time
if needPing {
pingTicker := time.NewTicker(pingInterval)
pingTicker := time.NewTicker(h.pingInterval)
defer pingTicker.Stop()
pingCh = pingTicker.C
}

View File

@@ -165,7 +165,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
subscription, _ := middleware.GetSubscriptionFromContext(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone)
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
// 0) wait queue check
maxWait := service.CalculateMaxWait(authSubject.Concurrency)
@@ -185,13 +185,16 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusTooManyRequests, err.Error())
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2) billing eligibility check (after wait)
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
googleError(c, http.StatusForbidden, err.Error())
status, _, message := billingErrorDetails(err)
googleError(c, status, message)
return
}
@@ -260,6 +263,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// 5) forward (根据平台分流)
var result *service.ForwardResult
@@ -373,7 +379,7 @@ func writeUpstreamResponse(c *gin.Context, res *service.UpstreamHTTPResult) {
}
for k, vv := range res.Headers {
// Avoid overriding content-length and hop-by-hop headers.
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") {
if strings.EqualFold(k, "Content-Length") || strings.EqualFold(k, "Transfer-Encoding") || strings.EqualFold(k, "Connection") || strings.EqualFold(k, "Www-Authenticate") {
continue
}
for _, v := range vv {

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -29,11 +30,16 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
if cfg != nil {
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
}
}
@@ -124,6 +130,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
if userReleaseFunc != nil {
defer userReleaseFunc()
}
@@ -131,7 +139,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// 2. Re-check billing eligibility after wait
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
status, code, message := billingErrorDetails(err)
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
@@ -201,6 +210,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err)
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
accountWaitRelease = wrapReleaseOnDone(c.Request.Context(), accountWaitRelease)
// Forward request
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)