merge upstream/main

This commit is contained in:
LLLLLLiulei
2026-02-06 11:33:45 +08:00
89 changed files with 10119 additions and 343 deletions

View File

@@ -0,0 +1,273 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求
type ErrorPassthroughHandler struct {
service *service.ErrorPassthroughService
}
// NewErrorPassthroughHandler 创建错误透传规则处理器
func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler {
return &ErrorPassthroughHandler{service: service}
}
// CreateErrorPassthroughRuleRequest 创建规则请求
type CreateErrorPassthroughRuleRequest struct {
Name string `json:"name" binding:"required"`
Enabled *bool `json:"enabled"`
Priority int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
Description *string `json:"description"`
}
// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选)
type UpdateErrorPassthroughRuleRequest struct {
Name *string `json:"name"`
Enabled *bool `json:"enabled"`
Priority *int `json:"priority"`
ErrorCodes []int `json:"error_codes"`
Keywords []string `json:"keywords"`
MatchMode *string `json:"match_mode"`
Platforms []string `json:"platforms"`
PassthroughCode *bool `json:"passthrough_code"`
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
Description *string `json:"description"`
}
// List 获取所有规则
// GET /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) List(c *gin.Context) {
rules, err := h.service.List(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rules)
}
// GetByID 根据 ID 获取规则
// GET /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
rule, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if rule == nil {
response.NotFound(c, "Rule not found")
return
}
response.Success(c, rule)
}
// Create 创建规则
// POST /api/v1/admin/error-passthrough-rules
func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
var req CreateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
rule := &model.ErrorPassthroughRule{
Name: req.Name,
Priority: req.Priority,
ErrorCodes: req.ErrorCodes,
Keywords: req.Keywords,
Platforms: req.Platforms,
}
// 设置默认值
if req.Enabled != nil {
rule.Enabled = *req.Enabled
} else {
rule.Enabled = true
}
if req.MatchMode != "" {
rule.MatchMode = req.MatchMode
} else {
rule.MatchMode = model.MatchModeAny
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
} else {
rule.PassthroughCode = true
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
} else {
rule.PassthroughBody = true
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
created, err := h.service.Create(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, created)
}
// Update 更新规则(支持部分更新)
// PUT /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
var req UpdateErrorPassthroughRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 先获取现有规则
existing, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
if existing == nil {
response.NotFound(c, "Rule not found")
return
}
// 部分更新:只更新请求中提供的字段
rule := &model.ErrorPassthroughRule{
ID: id,
Name: existing.Name,
Enabled: existing.Enabled,
Priority: existing.Priority,
ErrorCodes: existing.ErrorCodes,
Keywords: existing.Keywords,
MatchMode: existing.MatchMode,
Platforms: existing.Platforms,
PassthroughCode: existing.PassthroughCode,
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
Description: existing.Description,
}
// 应用请求中提供的更新
if req.Name != nil {
rule.Name = *req.Name
}
if req.Enabled != nil {
rule.Enabled = *req.Enabled
}
if req.Priority != nil {
rule.Priority = *req.Priority
}
if req.ErrorCodes != nil {
rule.ErrorCodes = req.ErrorCodes
}
if req.Keywords != nil {
rule.Keywords = req.Keywords
}
if req.MatchMode != nil {
rule.MatchMode = *req.MatchMode
}
if req.Platforms != nil {
rule.Platforms = req.Platforms
}
if req.PassthroughCode != nil {
rule.PassthroughCode = *req.PassthroughCode
}
if req.ResponseCode != nil {
rule.ResponseCode = req.ResponseCode
}
if req.PassthroughBody != nil {
rule.PassthroughBody = *req.PassthroughBody
}
if req.CustomMessage != nil {
rule.CustomMessage = req.CustomMessage
}
if req.Description != nil {
rule.Description = req.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
updated, err := h.service.Update(c.Request.Context(), rule)
if err != nil {
if _, ok := err.(*model.ValidationError); ok {
response.BadRequest(c, err.Error())
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, updated)
}
// Delete 删除规则
// DELETE /api/v1/admin/error-passthrough-rules/:id
func (h *ErrorPassthroughHandler) Delete(c *gin.Context) {
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid rule ID")
return
}
if err := h.service.Delete(c.Request.Context(), id); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Rule deleted successfully"})
}

View File

@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
}
// UpdateBalanceRequest represents balance update request
@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
}
response.Success(c, out)
}
// GetUserGroupRates 获取当前用户的专属分组倍率配置
// GET /api/v1/groups/rates
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, rates)
}

View File

@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
return nil
}
return &AdminUser{
User: *base,
Notes: u.Notes,
User: *base,
Notes: u.Notes,
GroupRates: u.GroupRates,
}
}

View File

@@ -29,6 +29,9 @@ type AdminUser struct {
User
Notes string `json:"notes"`
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
}
type APIKey struct {

View File

@@ -33,6 +33,7 @@ type GatewayHandler struct {
billingCacheService *service.BillingCacheService
usageService *service.UsageService
apiKeyService *service.APIKeyService
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
maxAccountSwitchesGemini int
@@ -48,6 +49,7 @@ func NewGatewayHandler(
billingCacheService *service.BillingCacheService,
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *GatewayHandler {
pingInterval := time.Duration(0)
@@ -70,6 +72,7 @@ func NewGatewayHandler(
billingCacheService: billingCacheService,
usageService: usageService,
apiKeyService: apiKeyService,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
@@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
@@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return
}
switchCount++
@@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false
for {
@@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
@@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return
}
switchCount++
@@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
@@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
handleGeminiFailoverExhausted(c, lastFailoverStatus)
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
account := selection.Account
@@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
handleGeminiFailoverExhausted(c, lastFailoverStatus)
lastFailoverErr = failoverErr
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
return
}
lastFailoverStatus = failoverErr.StatusCode
lastFailoverErr = failoverErr
switchCount++
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
return "", "", &pathParseError{"invalid model action path"}
}
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) {
if failoverErr == nil {
googleError(c, http.StatusBadGateway, "Upstream request failed")
return
}
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
googleError(c, respCode, msg)
return
}
}
// 使用默认的错误映射
status, message := mapGeminiUpstreamError(statusCode)
googleError(c, status, message)
}

View File

@@ -24,6 +24,7 @@ type AdminHandlers struct {
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
}
// Handlers contains all HTTP handlers

View File

@@ -22,11 +22,12 @@ import (
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
apiKeyService *service.APIKeyService
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
maxAccountSwitches int
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler(
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
apiKeyService *service.APIKeyService,
errorPassthroughService *service.ErrorPassthroughService,
cfg *config.Config,
) *OpenAIGatewayHandler {
pingInterval := time.Duration(0)
@@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler(
}
}
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
gatewayService: gatewayService,
billingCacheService: billingCacheService,
apiKeyService: apiKeyService,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
maxAccountSwitches: maxAccountSwitches,
}
}
@@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
maxAccountSwitches := h.maxAccountSwitches
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
lastFailoverStatus := 0
var lastFailoverErr *service.UpstreamFailoverError
for {
// Select account supporting the requested model
@@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
if lastFailoverErr != nil {
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
} else {
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
}
return
}
account := selection.Account
@@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if switchCount >= maxAccountSwitches {
lastFailoverStatus = failoverErr.StatusCode
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
lastFailoverStatus = failoverErr.StatusCode
switchCount++
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
continue
@@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
// 确定响应状态码
respCode := statusCode
if !rule.PassthroughCode && rule.ResponseCode != nil {
respCode = *rule.ResponseCode
}
// 确定响应消息
msg := service.ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
msg = *rule.CustomMessage
}
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
return
}
}
// 使用默认的错误映射
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
status, errType, errMsg := h.mapUpstreamError(statusCode)
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
}

View File

@@ -27,6 +27,7 @@ func ProvideAdminHandlers(
subscriptionHandler *admin.SubscriptionHandler,
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -47,6 +48,7 @@ func ProvideAdminHandlers(
Subscription: subscriptionHandler,
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
}
}
@@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet(
admin.NewSubscriptionHandler,
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -0,0 +1,74 @@
// Package model 定义服务层使用的数据模型。
package model
import "time"
// ErrorPassthroughRule 全局错误透传规则
// 用于控制上游错误如何返回给客户端
type ErrorPassthroughRule struct {
ID int64 `json:"id"`
Name string `json:"name"` // 规则名称
Enabled bool `json:"enabled"` // 是否启用
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表OR关系
Keywords []string `json:"keywords"` // 匹配的关键词列表OR关系
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
Platforms []string `json:"platforms"` // 适用平台列表
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
ResponseCode *int `json:"response_code"` // 自定义状态码passthrough_code=false 时使用)
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
CustomMessage *string `json:"custom_message"` // 自定义错误信息passthrough_body=false 时使用)
Description *string `json:"description"` // 规则描述
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// MatchModeAny 表示任一条件匹配即可
const MatchModeAny = "any"
// MatchModeAll 表示所有条件都必须匹配
const MatchModeAll = "all"
// 支持的平台常量
const (
PlatformAnthropic = "anthropic"
PlatformOpenAI = "openai"
PlatformGemini = "gemini"
PlatformAntigravity = "antigravity"
)
// AllPlatforms 返回所有支持的平台列表
func AllPlatforms() []string {
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
}
// Validate 验证规则配置的有效性
func (r *ErrorPassthroughRule) Validate() error {
if r.Name == "" {
return &ValidationError{Field: "name", Message: "name is required"}
}
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
}
// 至少需要配置一个匹配条件(错误码或关键词)
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
}
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
}
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
}
return nil
}
// ValidationError 表示验证错误
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return e.Field + ": " + e.Message
}

View File

@@ -71,6 +71,12 @@ var DefaultModels = []Model{
DisplayName: "Claude Opus 4.5",
CreatedAt: "2025-11-01T00:00:00Z",
},
{
ID: "claude-opus-4-6",
Type: "model",
DisplayName: "Claude Opus 4.6",
CreatedAt: "2026-02-06T00:00:00Z",
},
{
ID: "claude-sonnet-4-5-20250929",
Type: "model",

View File

@@ -0,0 +1,109 @@
// Package googleapi provides helpers for Google-style API responses.
package googleapi
import (
"encoding/json"
"fmt"
"strings"
)
// ErrorResponse represents a Google API error response
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail contains the error details from Google API
type ErrorDetail struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
Details []json.RawMessage `json:"details,omitempty"`
}
// ErrorDetailInfo contains additional error information
type ErrorDetailInfo struct {
Type string `json:"@type"`
Reason string `json:"reason,omitempty"`
Domain string `json:"domain,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}
// ErrorHelp contains help links
type ErrorHelp struct {
Type string `json:"@type"`
Links []HelpLink `json:"links,omitempty"`
}
// HelpLink represents a help link
type HelpLink struct {
Description string `json:"description"`
URL string `json:"url"`
}
// ParseError parses a Google API error response and extracts key information
func ParseError(body string) (*ErrorResponse, error) {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return nil, fmt.Errorf("failed to parse error response: %w", err)
}
return &errResp, nil
}
// ExtractActivationURL extracts the API activation URL from error details
func ExtractActivationURL(body string) string {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return ""
}
// Check error details for activation URL
for _, detailRaw := range errResp.Error.Details {
// Parse as ErrorDetailInfo
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Metadata != nil {
if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" {
return activationURL
}
}
}
// Parse as ErrorHelp
var help ErrorHelp
if err := json.Unmarshal(detailRaw, &help); err == nil {
for _, link := range help.Links {
if strings.Contains(link.Description, "activation") ||
strings.Contains(link.Description, "API activation") ||
strings.Contains(link.URL, "/apis/api/") {
return link.URL
}
}
}
}
return ""
}
// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error
func IsServiceDisabledError(body string) bool {
var errResp ErrorResponse
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
return false
}
// Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason
if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" {
return false
}
for _, detailRaw := range errResp.Error.Details {
var info ErrorDetailInfo
if err := json.Unmarshal(detailRaw, &info); err == nil {
if info.Reason == "SERVICE_DISABLED" {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,143 @@
package googleapi
import (
"testing"
)
func TestExtractActivationURL(t *testing.T) {
// Test case from the user's error message
errorBody := `{
"error": {
"code": 403,
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED",
"domain": "googleapis.com",
"metadata": {
"service": "cloudaicompanion.googleapis.com",
"activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843",
"consumer": "projects/project-6eca5881-ab73-4736-843",
"serviceTitle": "Gemini for Google Cloud API",
"containerInfo": "project-6eca5881-ab73-4736-843"
}
},
{
"@type": "type.googleapis.com/google.rpc.LocalizedMessage",
"locale": "en-US",
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry."
},
{
"@type": "type.googleapis.com/google.rpc.Help",
"links": [
{
"description": "Google developers console API activation",
"url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
}
]
}
]
}
}`
activationURL := ExtractActivationURL(errorBody)
expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
if activationURL != expectedURL {
t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL)
}
}
func TestIsServiceDisabledError(t *testing.T) {
tests := []struct {
name string
body string
expected bool
}{
{
name: "SERVICE_DISABLED error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "SERVICE_DISABLED"
}
]
}
}`,
expected: true,
},
{
name: "Other 403 error",
body: `{
"error": {
"code": 403,
"status": "PERMISSION_DENIED",
"details": [
{
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": "OTHER_REASON"
}
]
}
}`,
expected: false,
},
{
name: "404 error",
body: `{
"error": {
"code": 404,
"status": "NOT_FOUND"
}
}`,
expected: false,
},
{
name: "Invalid JSON",
body: `invalid json`,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsServiceDisabledError(tt.body)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestParseError(t *testing.T) {
errorBody := `{
"error": {
"code": 403,
"message": "API not enabled",
"status": "PERMISSION_DENIED"
}
}`
errResp, err := ParseError(errorBody)
if err != nil {
t.Fatalf("Failed to parse error: %v", err)
}
if errResp.Error.Code != 403 {
t.Errorf("Expected code 403, got %d", errResp.Error.Code)
}
if errResp.Error.Status != "PERMISSION_DENIED" {
t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status)
}
if errResp.Error.Message != "API not enabled" {
t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message)
}
}

View File

@@ -15,6 +15,8 @@ type Model struct {
// DefaultModels OpenAI models list
var DefaultModels = []Model{
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},

View File

@@ -1089,8 +1089,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
payload, id,
string(payload), id,
)
if err != nil {
return err
}

View File

@@ -0,0 +1,128 @@
package repository
import (
"context"
"encoding/json"
"log"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
errorPassthroughCacheKey = "error_passthrough_rules"
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
errorPassthroughCacheTTL = 24 * time.Hour
)
type errorPassthroughCache struct {
rdb *redis.Client
localCache []*model.ErrorPassthroughRule
localMu sync.RWMutex
}
// NewErrorPassthroughCache 创建错误透传规则缓存
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
return &errorPassthroughCache{
rdb: rdb,
}
}
// Get 从缓存获取规则列表
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
// 先检查本地缓存
c.localMu.RLock()
if c.localCache != nil {
rules := c.localCache
c.localMu.RUnlock()
return rules, true
}
c.localMu.RUnlock()
// 从 Redis 获取
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
if err != nil {
if err != redis.Nil {
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
}
return nil, false
}
var rules []*model.ErrorPassthroughRule
if err := json.Unmarshal(data, &rules); err != nil {
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
return nil, false
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return rules, true
}
// Set 设置缓存
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
data, err := json.Marshal(rules)
if err != nil {
return err
}
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
return err
}
// 更新本地缓存
c.localMu.Lock()
c.localCache = rules
c.localMu.Unlock()
return nil
}
// Invalidate 使缓存失效
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
// 清除本地缓存
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 清除 Redis 缓存
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
}
// NotifyUpdate 通知其他实例刷新缓存
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
}
// SubscribeUpdates 订阅缓存更新通知
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
go func() {
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
defer func() { _ = sub.Close() }()
ch := sub.Channel()
for {
select {
case <-ctx.Done():
return
case msg := <-ch:
if msg == nil {
return
}
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
c.localMu.Lock()
c.localCache = nil
c.localMu.Unlock()
// 调用处理函数
handler()
}
}
}()
}

View File

@@ -0,0 +1,178 @@
package repository
import (
"context"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type errorPassthroughRepository struct {
client *ent.Client
}
// NewErrorPassthroughRepository 创建错误透传规则仓库
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
return &errorPassthroughRepository{client: client}
}
// List 获取所有规则
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
rules, err := r.client.ErrorPassthroughRule.Query().
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]*model.ErrorPassthroughRule, len(rules))
for i, rule := range rules {
result[i] = r.toModel(rule)
}
return result, nil
}
// GetByID 根据 ID 获取规则
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
if err != nil {
if ent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return r.toModel(rule), nil
}
// Create 创建规则
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.Create().
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
}
created, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(created), nil
}
// Update 更新规则
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
SetName(rule.Name).
SetEnabled(rule.Enabled).
SetPriority(rule.Priority).
SetMatchMode(rule.MatchMode).
SetPassthroughCode(rule.PassthroughCode).
SetPassthroughBody(rule.PassthroughBody)
// 处理可选字段
if len(rule.ErrorCodes) > 0 {
builder.SetErrorCodes(rule.ErrorCodes)
} else {
builder.ClearErrorCodes()
}
if len(rule.Keywords) > 0 {
builder.SetKeywords(rule.Keywords)
} else {
builder.ClearKeywords()
}
if len(rule.Platforms) > 0 {
builder.SetPlatforms(rule.Platforms)
} else {
builder.ClearPlatforms()
}
if rule.ResponseCode != nil {
builder.SetResponseCode(*rule.ResponseCode)
} else {
builder.ClearResponseCode()
}
if rule.CustomMessage != nil {
builder.SetCustomMessage(*rule.CustomMessage)
} else {
builder.ClearCustomMessage()
}
if rule.Description != nil {
builder.SetDescription(*rule.Description)
} else {
builder.ClearDescription()
}
updated, err := builder.Save(ctx)
if err != nil {
return nil, err
}
return r.toModel(updated), nil
}
// Delete 删除规则
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
}
// toModel 将 Ent 实体转换为服务模型
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
rule := &model.ErrorPassthroughRule{
ID: int64(e.ID),
Name: e.Name,
Enabled: e.Enabled,
Priority: e.Priority,
ErrorCodes: e.ErrorCodes,
Keywords: e.Keywords,
MatchMode: e.MatchMode,
Platforms: e.Platforms,
PassthroughCode: e.PassthroughCode,
PassthroughBody: e.PassthroughBody,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
if e.ResponseCode != nil {
rule.ResponseCode = e.ResponseCode
}
if e.CustomMessage != nil {
rule.CustomMessage = e.CustomMessage
}
if e.Description != nil {
rule.Description = e.Description
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {
rule.ErrorCodes = []int{}
}
if rule.Keywords == nil {
rule.Keywords = []string{}
}
if rule.Platforms == nil {
rule.Platforms = []string{}
}
return rule
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
body := resp.String()
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
}
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
body := resp.String()
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
// Check if this is a SERVICE_DISABLED error and extract activation URL
if googleapi.IsServiceDisabledError(body) {
activationURL := googleapi.ExtractActivationURL(body)
if activationURL != "" {
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
}
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
}
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
}
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"fmt"
"log"
"strconv"
"time"
@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
// 预加载 Lua 脚本到 Redis避免 Pipeline 中出现 NOSCRIPT 错误
ctx := context.Background()
scripts := []*redis.Script{
registerSessionScript,
refreshSessionScript,
getActiveSessionCountScript,
isSessionActiveScript,
}
for _, script := range scripts {
if err := script.Load(ctx, rdb).Err(); err != nil {
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
}
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,

View File

@@ -0,0 +1,113 @@
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userGroupRateRepository struct {
sql sqlExecutor
}
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
return &userGroupRateRepository{sql: sqlDB}
}
// GetByUserID 获取用户的所有专属分组倍率
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
rows, err := r.sql.QueryContext(ctx, query, userID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make(map[int64]float64)
for rows.Next() {
var groupID int64
var rate float64
if err := rows.Scan(&groupID, &rate); err != nil {
return nil, err
}
result[groupID] = rate
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
var rate float64
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &rate, nil
}
// SyncUserGroupRates 同步用户的分组专属倍率
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
if len(rates) == 0 {
// 如果传入空 map删除该用户的所有专属倍率
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
if err != nil {
return err
}
}
return nil
}
// DeleteByGroupID 删除指定分组的所有用户专属倍率
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
return err
}
// DeleteByUserID 删除指定用户的所有专属倍率
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
return err
}

View File

@@ -66,6 +66,8 @@ var ProviderSet = wire.NewSet(
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
// Cache implementations
NewGatewayCache,
@@ -86,6 +88,7 @@ var ProviderSet = wire.NewSet(
NewProxyLatencyCache,
NewTotpCache,
NewRefreshTokenCache,
NewErrorPassthroughCache,
// Encryptors
NewAESEncryptor,

View File

@@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps {
}
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
@@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)

View File

@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache
&config.Config{},
)
@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)

View File

@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))

View File

@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
// 用户属性管理
registerUserAttributeRoutes(admin, h)
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
}
}
@@ -391,3 +394,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{
rules.GET("", h.Admin.ErrorPassthrough.List)
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
rules.POST("", h.Admin.ErrorPassthrough.Create)
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
}
}

View File

@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
// 使用记录

View File

@@ -94,6 +94,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
}
type CreateGroupInput struct {
@@ -296,6 +299,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
@@ -310,6 +314,7 @@ func NewAdminService(
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
@@ -322,6 +327,7 @@ func NewAdminService(
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
@@ -336,11 +342,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
@@ -409,6 +439,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
@@ -944,6 +982,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil {
return err
}
// 注意user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {

View File

@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}

View File

@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil
}
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {

View File

@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}

View File

@@ -0,0 +1,300 @@
package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
type ErrorPassthroughRepository interface {
// List 获取所有规则
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
// GetByID 根据 ID 获取规则
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
// Create 创建规则
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Update 更新规则
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Delete 删除规则
Delete(ctx context.Context, id int64) error
}
// ErrorPassthroughCache 定义错误透传规则的缓存接口
type ErrorPassthroughCache interface {
// Get 从缓存获取规则列表
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
// Set 设置缓存
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
// Invalidate 使缓存失效
Invalidate(ctx context.Context) error
// NotifyUpdate 通知其他实例刷新缓存
NotifyUpdate(ctx context.Context) error
// SubscribeUpdates 订阅缓存更新通知
SubscribeUpdates(ctx context.Context, handler func())
}
// ErrorPassthroughService 错误透传规则服务
type ErrorPassthroughService struct {
repo ErrorPassthroughRepository
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule
localCacheMu sync.RWMutex
}
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
cache ErrorPassthroughCache,
) *ErrorPassthroughService {
svc := &ErrorPassthroughService{
repo: repo,
cache: cache,
}
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
}
// 订阅缓存更新通知
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
return svc
}
// List 获取所有规则
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return s.repo.List(ctx)
}
// GetByID 根据 ID 获取规则
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
return s.repo.GetByID(ctx, id)
}
// Create 创建规则
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
created, err := s.repo.Create(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return created, nil
}
// Update 更新规则
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
updated, err := s.repo.Update(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return updated, nil
}
// Delete 删除规则
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
if err := s.repo.Delete(ctx, id); err != nil {
return err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return nil
}
// MatchRule 匹配透传规则
// 返回第一个匹配的规则,如果没有匹配则返回 nil
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
rules := s.getCachedRules()
if len(rules) == 0 {
return nil
}
bodyStr := strings.ToLower(string(body))
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatches(rule, platform) {
continue
}
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
}
}
return nil
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
if rules != nil {
return rules
}
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
s.localCacheMu.RLock()
defer s.localCacheMu.RUnlock()
return s.localCache
}
// refreshLocalCache 刷新本地缓存
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
// 先尝试从 Redis 缓存获取
if s.cache != nil {
if rules, ok := s.cache.Get(ctx); ok {
s.setLocalCache(rules)
return nil
}
}
// 从数据库加载repo.List 已按 priority 排序)
rules, err := s.repo.List(ctx)
if err != nil {
return err
}
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
// 更新本地缓存setLocalCache 内部会确保排序)
s.setLocalCache(rules)
return nil
}
// setLocalCache 设置本地缓存
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
// 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
s.localCacheMu.Lock()
s.localCache = sorted
s.localCacheMu.Unlock()
}
// invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 刷新本地缓存
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
}
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
return true
}
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true
}
}
return false
}
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
}
// "any" 模式:任一条件满足即可
if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch
}
return codeMatch && keywordMatch
}
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true
}
}
return false
}

View File

@@ -0,0 +1,755 @@
//go:build unit
package service
import (
"context"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockErrorPassthroughRepo 用于测试的 mock repository
type mockErrorPassthroughRepo struct {
rules []*model.ErrorPassthroughRule
}
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return m.rules, nil
}
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
for _, r := range m.rules {
if r.ID == id {
return r, nil
}
}
return nil, nil
}
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
rule.ID = int64(len(m.rules) + 1)
m.rules = append(m.rules, rule)
return rule, nil
}
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
for i, r := range m.rules {
if r.ID == rule.ID {
m.rules[i] = rule
return rule, nil
}
}
return rule, nil
}
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
for i, r := range m.rules {
if r.ID == id {
m.rules = append(m.rules[:i], m.rules[i+1:]...)
return nil
}
}
return nil
}
// newTestService 创建测试用的服务实例
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
repo := &mockErrorPassthroughRepo{rules: rules}
svc := &ErrorPassthroughService{
repo: repo,
cache: nil, // 不使用缓存
}
// 直接设置本地缓存,避免调用 refreshLocalCache
svc.setLocalCache(rules)
return svc
}
// =============================================================================
// 测试 ruleMatches 核心匹配逻辑
// =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
"没有配置条件时不应该匹配")
}
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"状态码匹配 422", 422, "any message", true},
{"状态码匹配 400", 400, "any message", true},
{"状态码不匹配 500", 500, "any message", false},
{"状态码不匹配 429", 429, "any message", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 MatchRule 的行为:先转换为小写
bodyLower := strings.ToLower(tt.body)
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: true,
reason: "code matches, keyword doesn't - OR mode should match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: true,
reason: "keyword matches, code doesn't - OR mode should match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match - AND mode should match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: false,
reason: "code matches but keyword doesn't - AND mode should NOT match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: false,
reason: "keyword matches but code doesn't - AND mode should NOT match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
// =============================================================================
// 测试 platformMatches 平台匹配逻辑
// =============================================================================
func TestPlatformMatches(t *testing.T) {
svc := newTestService(nil)
tests := []struct {
name string
rulePlatforms []string
requestPlatform string
expected bool
}{
{
name: "空平台列表匹配所有",
rulePlatforms: []string{},
requestPlatform: "anthropic",
expected: true,
},
{
name: "nil平台列表匹配所有",
rulePlatforms: nil,
requestPlatform: "openai",
expected: true,
},
{
name: "精确匹配 anthropic",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "精确匹配 openai",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "openai",
expected: true,
},
{
name: "不匹配 gemini",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "gemini",
expected: false,
},
{
name: "大小写不敏感",
rulePlatforms: []string{"Anthropic", "OpenAI"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "匹配 antigravity",
rulePlatforms: []string{"antigravity"},
requestPlatform: "antigravity",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := &model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms,
}
result := svc.platformMatches(rule, tt.requestPlatform)
assert.Equal(t, tt.expected, result)
})
}
}
// =============================================================================
// 测试 MatchRule 完整匹配流程
// =============================================================================
func TestMatchRule_Priority(t *testing.T) {
// 测试规则按优先级排序,优先级小的先匹配
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Low Priority",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "High Priority",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
assert.Equal(t, "High Priority", matched.Name)
}
func TestMatchRule_DisabledRule(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Disabled Rule",
Enabled: false,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "Enabled Rule",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
}
func TestMatchRule_PlatformFilter(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Anthropic Only",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Platforms: []string{"anthropic"},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "OpenAI Only",
Enabled: true,
Priority: 2,
ErrorCodes: []int{422},
Platforms: []string{"openai"},
MatchMode: model.MatchModeAny,
},
{
ID: 3,
Name: "All Platforms",
Enabled: true,
Priority: 3,
ErrorCodes: []int{422},
Platforms: []string{},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(1), matched.ID)
})
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
matched := svc.MatchRule("openai", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID)
})
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("gemini", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("antigravity", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
}
func TestMatchRule_NoMatch(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Rule for 422",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("error"))
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
}
func TestMatchRule_EmptyRules(t *testing.T) {
svc := newTestService([]*model.ErrorPassthroughRule{})
matched := svc.MatchRule("anthropic", 422, []byte("error"))
assert.Nil(t, matched, "没有规则时应返回 nil")
}
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit",
Enabled: true,
Priority: 1,
Keywords: []string{"Context Limit"},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
tests := []struct {
name string
body string
expected bool
}{
{"完全匹配", "Context Limit reached", true},
{"小写匹配", "context limit reached", true},
{"大写匹配", "CONTEXT LIMIT REACHED", true},
{"混合大小写", "ConTeXt LiMiT error", true},
{"不匹配", "some other error", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
if tt.expected {
assert.NotNil(t, matched)
} else {
assert.Nil(t, matched)
}
})
}
}
// =============================================================================
// 测试真实场景
// =============================================================================
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit Passthrough",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll, // 必须同时满足
Platforms: []string{"anthropic", "antigravity"},
PassthroughCode: true,
PassthroughBody: true,
},
}
svc := newTestService(rules)
// 测试 Anthropic 平台
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
matched := svc.MatchRule("anthropic", 422, body)
require.NotNil(t, matched)
assert.True(t, matched.PassthroughCode)
assert.True(t, matched.PassthroughBody)
})
// 测试 Antigravity 平台
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("antigravity", 422, body)
require.NotNil(t, matched)
})
// 测试 OpenAI 平台(不在规则的平台列表中)
t.Run("OpenAI should not match", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("openai", 422, body)
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
})
// 测试状态码不匹配
t.Run("Wrong status code", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("anthropic", 400, body)
assert.Nil(t, matched, "状态码不匹配")
})
// 测试关键词不匹配
t.Run("Wrong keyword", func(t *testing.T) {
body := []byte(`{"error":"rate limit exceeded"}`)
matched := svc.MatchRule("anthropic", 422, body)
assert.Nil(t, matched, "关键词不匹配")
})
}
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
customMsg := "Service temporarily unavailable, please try again later"
responseCode := 503
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Hide Internal Errors",
Enabled: true,
Priority: 1,
ErrorCodes: []int{500, 502, 503},
MatchMode: model.MatchModeAny,
PassthroughCode: false,
ResponseCode: &responseCode,
PassthroughBody: false,
CustomMessage: &customMsg,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
require.NotNil(t, matched)
assert.False(t, matched.PassthroughCode)
assert.Equal(t, 503, *matched.ResponseCode)
assert.False(t, matched.PassthroughBody)
assert.Equal(t, customMsg, *matched.CustomMessage)
}
// =============================================================================
// 测试 model.Validate
// =============================================================================
func TestErrorPassthroughRule_Validate(t *testing.T) {
tests := []struct {
name string
rule *model.ErrorPassthroughRule
expectError bool
errorField string
}{
{
name: "有效规则 - 透传模式(含错误码)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 透传模式(含关键词)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
Keywords: []string{"context limit"},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 自定义响应",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAll,
ErrorCodes: []int{500},
Keywords: []string{"internal error"},
PassthroughCode: false,
ResponseCode: testIntPtr(503),
PassthroughBody: false,
CustomMessage: testStrPtr("Custom error"),
},
expectError: false,
},
{
name: "缺少名称",
rule: &model.ErrorPassthroughRule{
Name: "",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "name",
},
{
name: "无效的匹配模式",
rule: &model.ErrorPassthroughRule{
Name: "Invalid Mode",
MatchMode: "invalid",
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "match_mode",
},
{
name: "缺少匹配条件(错误码和关键词都为空)",
rule: &model.ErrorPassthroughRule{
Name: "No Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{},
Keywords: []string{},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "缺少匹配条件nil切片",
rule: &model.ErrorPassthroughRule{
Name: "Nil Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: nil,
Keywords: nil,
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "自定义状态码但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Code",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: false,
ResponseCode: nil,
PassthroughBody: true,
},
expectError: true,
errorField: "response_code",
},
{
name: "自定义消息但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: nil,
},
expectError: true,
errorField: "custom_message",
},
{
name: "自定义消息为空字符串",
rule: &model.ErrorPassthroughRule{
Name: "Empty Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: testStrPtr(""),
},
expectError: true,
errorField: "custom_message",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.rule.Validate()
if tt.expectError {
require.Error(t, err)
validationErr, ok := err.(*model.ValidationError)
require.True(t, ok, "应该返回 ValidationError")
assert.Equal(t, tt.errorField, validationErr.Field)
} else {
assert.NoError(t, err)
}
})
}
}
// Helper functions
func testIntPtr(i int) *int { return &i }
func testStrPtr(s string) *string { return &s }

View File

@@ -0,0 +1,288 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ---------- reconcileCachedTokens 单元测试 ----------
func TestReconcileCachedTokens_NilUsage(t *testing.T) {
assert.False(t, reconcileCachedTokens(nil))
}
func TestReconcileCachedTokens_AlreadyHasCacheRead(t *testing.T) {
// 已有标准字段,不应覆盖
usage := map[string]any{
"cache_read_input_tokens": float64(100),
"cached_tokens": float64(50),
}
assert.False(t, reconcileCachedTokens(usage))
assert.Equal(t, float64(100), usage["cache_read_input_tokens"])
}
func TestReconcileCachedTokens_KimiStyle(t *testing.T) {
// Kimi 风格cache_read_input_tokens=0cached_tokens>0
usage := map[string]any{
"input_tokens": float64(23),
"cache_creation_input_tokens": float64(0),
"cache_read_input_tokens": float64(0),
"cached_tokens": float64(23),
}
assert.True(t, reconcileCachedTokens(usage))
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
}
func TestReconcileCachedTokens_NoCachedTokens(t *testing.T) {
// 无 cached_tokens 字段(原生 Claude
usage := map[string]any{
"input_tokens": float64(100),
"cache_read_input_tokens": float64(0),
"cache_creation_input_tokens": float64(0),
}
assert.False(t, reconcileCachedTokens(usage))
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
}
func TestReconcileCachedTokens_CachedTokensZero(t *testing.T) {
// cached_tokens 为 0不应覆盖
usage := map[string]any{
"cache_read_input_tokens": float64(0),
"cached_tokens": float64(0),
}
assert.False(t, reconcileCachedTokens(usage))
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
}
func TestReconcileCachedTokens_MissingCacheReadField(t *testing.T) {
// cache_read_input_tokens 字段完全不存在cached_tokens > 0
usage := map[string]any{
"cached_tokens": float64(42),
}
assert.True(t, reconcileCachedTokens(usage))
assert.Equal(t, float64(42), usage["cache_read_input_tokens"])
}
// ---------- 流式 message_start 事件 reconcile 测试 ----------
func TestStreamingReconcile_MessageStart(t *testing.T) {
// 模拟 Kimi 返回的 message_start SSE 事件
eventJSON := `{
"type": "message_start",
"message": {
"id": "msg_123",
"type": "message",
"role": "assistant",
"model": "kimi",
"usage": {
"input_tokens": 23,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
"cached_tokens": 23
}
}
}`
var event map[string]any
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
eventType, _ := event["type"].(string)
require.Equal(t, "message_start", eventType)
// 模拟 processSSEEvent 中的 reconcile 逻辑
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
reconcileCachedTokens(u)
}
}
// 验证 cache_read_input_tokens 已被填充
msg, ok := event["message"].(map[string]any)
require.True(t, ok)
usage, ok := msg["usage"].(map[string]any)
require.True(t, ok)
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
// 验证重新序列化后 JSON 也包含正确值
data, err := json.Marshal(event)
require.NoError(t, err)
assert.Equal(t, int64(23), gjson.GetBytes(data, "message.usage.cache_read_input_tokens").Int())
}
func TestStreamingReconcile_MessageStart_NativeClaude(t *testing.T) {
// 原生 Claude 不返回 cached_tokensreconcile 不应改变任何值
eventJSON := `{
"type": "message_start",
"message": {
"usage": {
"input_tokens": 100,
"cache_creation_input_tokens": 50,
"cache_read_input_tokens": 30
}
}
}`
var event map[string]any
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
reconcileCachedTokens(u)
}
}
msg, ok := event["message"].(map[string]any)
require.True(t, ok)
usage, ok := msg["usage"].(map[string]any)
require.True(t, ok)
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
}
// ---------- 流式 message_delta 事件 reconcile 测试 ----------
func TestStreamingReconcile_MessageDelta(t *testing.T) {
// 模拟 Kimi 返回的 message_delta SSE 事件
eventJSON := `{
"type": "message_delta",
"usage": {
"output_tokens": 7,
"cache_read_input_tokens": 0,
"cached_tokens": 15
}
}`
var event map[string]any
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
eventType, _ := event["type"].(string)
require.Equal(t, "message_delta", eventType)
// 模拟 processSSEEvent 中的 reconcile 逻辑
usage, ok := event["usage"].(map[string]any)
require.True(t, ok)
reconcileCachedTokens(usage)
assert.Equal(t, float64(15), usage["cache_read_input_tokens"])
}
func TestStreamingReconcile_MessageDelta_NativeClaude(t *testing.T) {
// 原生 Claude 的 message_delta 通常没有 cached_tokens
eventJSON := `{
"type": "message_delta",
"usage": {
"output_tokens": 50
}
}`
var event map[string]any
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
usage, ok := event["usage"].(map[string]any)
require.True(t, ok)
reconcileCachedTokens(usage)
_, hasCacheRead := usage["cache_read_input_tokens"]
assert.False(t, hasCacheRead, "不应为原生 Claude 响应注入 cache_read_input_tokens")
}
// ---------- 非流式响应 reconcile 测试 ----------
func TestNonStreamingReconcile_KimiResponse(t *testing.T) {
// 模拟 Kimi 非流式响应
body := []byte(`{
"id": "msg_123",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "hello"}],
"model": "kimi",
"usage": {
"input_tokens": 23,
"output_tokens": 7,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
"cached_tokens": 23,
"prompt_tokens": 23,
"completion_tokens": 7
}
}`)
// 模拟 handleNonStreamingResponse 中的逻辑
var response struct {
Usage ClaudeUsage `json:"usage"`
}
require.NoError(t, json.Unmarshal(body, &response))
// reconcile
if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
if cachedTokens > 0 {
response.Usage.CacheReadInputTokens = int(cachedTokens)
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
body = newBody
}
}
}
// 验证内部 usage计费用
assert.Equal(t, 23, response.Usage.CacheReadInputTokens)
assert.Equal(t, 23, response.Usage.InputTokens)
assert.Equal(t, 7, response.Usage.OutputTokens)
// 验证返回给客户端的 JSON body
assert.Equal(t, int64(23), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
}
func TestNonStreamingReconcile_NativeClaude(t *testing.T) {
// 原生 Claude 响应cache_read_input_tokens 已有值
body := []byte(`{
"usage": {
"input_tokens": 100,
"output_tokens": 50,
"cache_creation_input_tokens": 20,
"cache_read_input_tokens": 30
}
}`)
var response struct {
Usage ClaudeUsage `json:"usage"`
}
require.NoError(t, json.Unmarshal(body, &response))
// CacheReadInputTokens == 30条件不成立整个 reconcile 分支不会执行
assert.NotZero(t, response.Usage.CacheReadInputTokens)
assert.Equal(t, 30, response.Usage.CacheReadInputTokens)
}
func TestNonStreamingReconcile_NoCachedTokens(t *testing.T) {
// 没有 cached_tokens 字段
body := []byte(`{
"usage": {
"input_tokens": 100,
"output_tokens": 50,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0
}
}`)
var response struct {
Usage ClaudeUsage `json:"usage"`
}
require.NoError(t, json.Unmarshal(body, &response))
if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
if cachedTokens > 0 {
response.Usage.CacheReadInputTokens = int(cachedTokens)
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
body = newBody
}
}
}
// cache_read_input_tokens 应保持为 0
assert.Equal(t, 0, response.Usage.CacheReadInputTokens)
assert.Equal(t, int64(0), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
}

View File

@@ -370,7 +370,8 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
}
func (e *UpstreamFailoverError) Error() string {
@@ -384,6 +385,7 @@ type GatewayService struct {
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
@@ -405,6 +407,7 @@ func NewGatewayService(
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
@@ -424,6 +427,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
@@ -3281,7 +3285,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -3311,10 +3315,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
@@ -3358,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
@@ -3755,6 +3757,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false
}
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string {
return extractUpstreamErrorMessage(body)
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
@@ -3822,7 +3830,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
}
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
@@ -4168,6 +4176,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
eventName = eventType
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
if u, ok := msg["usage"].(map[string]any); ok {
reconcileCachedTokens(u)
}
}
}
if eventType == "message_delta" {
if u, ok := event["usage"].(map[string]any); ok {
reconcileCachedTokens(u)
}
}
if needModelReplace {
if msg, ok := event["message"].(map[string]any); ok {
if model, ok := msg["model"].(string); ok && model == mappedModel {
@@ -4518,6 +4540,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
return nil, fmt.Errorf("parse response: %w", err)
}
// 兼容 Kimi cached_tokens → cache_read_input_tokens
if response.Usage.CacheReadInputTokens == 0 {
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
if cachedTokens > 0 {
response.Usage.CacheReadInputTokens = int(cachedTokens)
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
body = newBody
}
}
}
// 如果有模型映射替换响应中的model字段
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
@@ -4609,10 +4642,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account
subscription := input.Subscription
// 获取费率倍数
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
}
var cost *CostBreakdown
@@ -4773,10 +4813,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account
subscription := input.Subscription
// 获取费率倍数
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
}
var cost *CostBreakdown
@@ -5289,3 +5336,21 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
return models
}
// reconcileCachedTokens 兼容 Kimi 等上游:
// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens
func reconcileCachedTokens(usage map[string]any) bool {
if usage == nil {
return false
}
cacheRead, _ := usage["cache_read_input_tokens"].(float64)
if cacheRead > 0 {
return false // 已有标准字段,无需处理
}
cached, _ := usage["cached_tokens"].(float64)
if cached <= 0 {
return false
}
usage["cache_read_input_tokens"] = cached
return true
}

View File

@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
}
respBody = unwrapIfNeeded(isOAuth, respBody)

View File

@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
// 当 LoadCodeAssist 返回了 currentTier / paidTier表示账号已注册但没有返回 cloudaicompanionProject 时:
// - 不要再调用 onboardUser通常不会再分配 project_id且可能触发 INVALID_ARGUMENT
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
if loadResp != nil {
registeredTierID := strings.TrimSpace(loadResp.GetTier())
if registeredTierID != "" {
// 已注册但未返回 cloudaicompanionProject这在 Google One 用户中较常见:需要用户自行提供 project_id。
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
// Try to get project from Cloud Resource Manager
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
return strings.TrimSpace(fallback), tierID, nil
}
// No project found - user must provide project_id manually
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
}
}
// 未检测到 currentTier/paidTier视为新用户继续调用 onboardUser
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
req := &geminicli.OnboardUserRequest{
TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{

View File

@@ -21,6 +21,17 @@ const (
var codexCLIInstructions string
var codexModelMap = map[string]string{
"gpt-5.3": "gpt-5.3",
"gpt-5.3-none": "gpt-5.3",
"gpt-5.3-low": "gpt-5.3",
"gpt-5.3-medium": "gpt-5.3",
"gpt-5.3-high": "gpt-5.3",
"gpt-5.3-xhigh": "gpt-5.3",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-low": "gpt-5.3-codex",
"gpt-5.3-codex-medium": "gpt-5.3-codex",
"gpt-5.3-codex-high": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt-5.1-codex": "gpt-5.1-codex",
"gpt-5.1-codex-low": "gpt-5.1-codex",
"gpt-5.1-codex-medium": "gpt-5.1-codex",
@@ -156,6 +167,12 @@ func normalizeCodexModel(model string) string {
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
return "gpt-5.3-codex"
}
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
return "gpt-5.3"
}
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
return "gpt-5.1-codex-max"
}

View File

@@ -176,6 +176,19 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
require.Len(t, input, 0)
}
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
cases := map[string]string{
"gpt-5.3": "gpt-5.3",
"gpt-5.3-codex": "gpt-5.3-codex",
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
"gpt 5.3 codex": "gpt-5.3-codex",
}
for input, expected := range cases {
require.Equal(t, expected, normalizeCodexModel(input))
}
}
func setupCodexCache(t *testing.T) {
t.Helper()

View File

@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -1131,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
}
// Return appropriate error response

View File

@@ -579,6 +579,7 @@ func (s *PricingService) extractBaseName(model string) string {
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// Claude模型系列匹配规则
familyPatterns := map[string][]string{
"opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"},
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
@@ -651,7 +652,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// 回退顺序:
// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等)
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
// 3. gpt-5.3-codex -> gpt-5.2-codex
// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex)
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
// 尝试的回退变体
variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
@@ -663,6 +665,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
}
}
if strings.HasPrefix(model, "gpt-5.3-codex") {
if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok {
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")
return pricing
}
}
// 最终回退到 DefaultTestModel
defaultModel := strings.ToLower(openai.DefaultTestModel)
if pricing, ok := s.pricingData[defaultModel]; ok {

View File

@@ -21,6 +21,10 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
// CanBindGroup checks whether a user can bind to a given group.
// For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
// - Public groups (non-exclusive): all users can bind
// - Exclusive groups: only users with the group in AllowedGroups can bind
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
// 公开分组(非专属):所有用户都可以绑定
if !isExclusive {
return true
}
return !isExclusive
// 专属分组:需要在 AllowedGroups 中
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
func (u *User) SetPassword(password string) error {

View File

@@ -0,0 +1,25 @@
package service
import "context"
// UserGroupRateRepository 用户专属分组倍率仓储接口
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
type UserGroupRateRepository interface {
// GetByUserID 获取用户的所有专属分组倍率
// 返回 map[groupID]rateMultiplier
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// 如果未设置专属倍率,返回 nil
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
// SyncUserGroupRates 同步用户的分组专属倍率
// rates: map[groupID]*rateMultipliernil 表示删除该分组的专属倍率
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
DeleteByGroupID(ctx context.Context, groupID int64) error
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet(
NewUserAttributeService,
NewUsageCache,
NewTotpService,
NewErrorPassthroughService,
)