merge upstream/main
This commit is contained in:
273
backend/internal/handler/admin/error_passthrough_handler.go
Normal file
273
backend/internal/handler/admin/error_passthrough_handler.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求
|
||||
type ErrorPassthroughHandler struct {
|
||||
service *service.ErrorPassthroughService
|
||||
}
|
||||
|
||||
// NewErrorPassthroughHandler 创建错误透传规则处理器
|
||||
func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler {
|
||||
return &ErrorPassthroughHandler{service: service}
|
||||
}
|
||||
|
||||
// CreateErrorPassthroughRuleRequest 创建规则请求
|
||||
type CreateErrorPassthroughRuleRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Priority int `json:"priority"`
|
||||
ErrorCodes []int `json:"error_codes"`
|
||||
Keywords []string `json:"keywords"`
|
||||
MatchMode string `json:"match_mode"`
|
||||
Platforms []string `json:"platforms"`
|
||||
PassthroughCode *bool `json:"passthrough_code"`
|
||||
ResponseCode *int `json:"response_code"`
|
||||
PassthroughBody *bool `json:"passthrough_body"`
|
||||
CustomMessage *string `json:"custom_message"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选)
|
||||
type UpdateErrorPassthroughRuleRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Priority *int `json:"priority"`
|
||||
ErrorCodes []int `json:"error_codes"`
|
||||
Keywords []string `json:"keywords"`
|
||||
MatchMode *string `json:"match_mode"`
|
||||
Platforms []string `json:"platforms"`
|
||||
PassthroughCode *bool `json:"passthrough_code"`
|
||||
ResponseCode *int `json:"response_code"`
|
||||
PassthroughBody *bool `json:"passthrough_body"`
|
||||
CustomMessage *string `json:"custom_message"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
// GET /api/v1/admin/error-passthrough-rules
|
||||
func (h *ErrorPassthroughHandler) List(c *gin.Context) {
|
||||
rules, err := h.service.List(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, rules)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
// GET /api/v1/admin/error-passthrough-rules/:id
|
||||
func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if rule == nil {
|
||||
response.NotFound(c, "Rule not found")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, rule)
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
// POST /api/v1/admin/error-passthrough-rules
|
||||
func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
|
||||
var req CreateErrorPassthroughRuleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Name: req.Name,
|
||||
Priority: req.Priority,
|
||||
ErrorCodes: req.ErrorCodes,
|
||||
Keywords: req.Keywords,
|
||||
Platforms: req.Platforms,
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if req.Enabled != nil {
|
||||
rule.Enabled = *req.Enabled
|
||||
} else {
|
||||
rule.Enabled = true
|
||||
}
|
||||
if req.MatchMode != "" {
|
||||
rule.MatchMode = req.MatchMode
|
||||
} else {
|
||||
rule.MatchMode = model.MatchModeAny
|
||||
}
|
||||
if req.PassthroughCode != nil {
|
||||
rule.PassthroughCode = *req.PassthroughCode
|
||||
} else {
|
||||
rule.PassthroughCode = true
|
||||
}
|
||||
if req.PassthroughBody != nil {
|
||||
rule.PassthroughBody = *req.PassthroughBody
|
||||
} else {
|
||||
rule.PassthroughBody = true
|
||||
}
|
||||
rule.ResponseCode = req.ResponseCode
|
||||
rule.CustomMessage = req.CustomMessage
|
||||
rule.Description = req.Description
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
created, err := h.service.Create(c.Request.Context(), rule)
|
||||
if err != nil {
|
||||
if _, ok := err.(*model.ValidationError); ok {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, created)
|
||||
}
|
||||
|
||||
// Update 更新规则(支持部分更新)
|
||||
// PUT /api/v1/admin/error-passthrough-rules/:id
|
||||
func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateErrorPassthroughRuleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 先获取现有规则
|
||||
existing, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
response.NotFound(c, "Rule not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 部分更新:只更新请求中提供的字段
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: id,
|
||||
Name: existing.Name,
|
||||
Enabled: existing.Enabled,
|
||||
Priority: existing.Priority,
|
||||
ErrorCodes: existing.ErrorCodes,
|
||||
Keywords: existing.Keywords,
|
||||
MatchMode: existing.MatchMode,
|
||||
Platforms: existing.Platforms,
|
||||
PassthroughCode: existing.PassthroughCode,
|
||||
ResponseCode: existing.ResponseCode,
|
||||
PassthroughBody: existing.PassthroughBody,
|
||||
CustomMessage: existing.CustomMessage,
|
||||
Description: existing.Description,
|
||||
}
|
||||
|
||||
// 应用请求中提供的更新
|
||||
if req.Name != nil {
|
||||
rule.Name = *req.Name
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
rule.Enabled = *req.Enabled
|
||||
}
|
||||
if req.Priority != nil {
|
||||
rule.Priority = *req.Priority
|
||||
}
|
||||
if req.ErrorCodes != nil {
|
||||
rule.ErrorCodes = req.ErrorCodes
|
||||
}
|
||||
if req.Keywords != nil {
|
||||
rule.Keywords = req.Keywords
|
||||
}
|
||||
if req.MatchMode != nil {
|
||||
rule.MatchMode = *req.MatchMode
|
||||
}
|
||||
if req.Platforms != nil {
|
||||
rule.Platforms = req.Platforms
|
||||
}
|
||||
if req.PassthroughCode != nil {
|
||||
rule.PassthroughCode = *req.PassthroughCode
|
||||
}
|
||||
if req.ResponseCode != nil {
|
||||
rule.ResponseCode = req.ResponseCode
|
||||
}
|
||||
if req.PassthroughBody != nil {
|
||||
rule.PassthroughBody = *req.PassthroughBody
|
||||
}
|
||||
if req.CustomMessage != nil {
|
||||
rule.CustomMessage = req.CustomMessage
|
||||
}
|
||||
if req.Description != nil {
|
||||
rule.Description = req.Description
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
updated, err := h.service.Update(c.Request.Context(), rule)
|
||||
if err != nil {
|
||||
if _, ok := err.(*model.ValidationError); ok {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, updated)
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
// DELETE /api/v1/admin/error-passthrough-rules/:id
|
||||
func (h *ErrorPassthroughHandler) Delete(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.Delete(c.Request.Context(), id); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Rule deleted successfully"})
|
||||
}
|
||||
@@ -45,6 +45,9 @@ type UpdateUserRequest struct {
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
AllowedGroups *[]int64 `json:"allowed_groups"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64 `json:"group_rates"`
|
||||
}
|
||||
|
||||
// UpdateBalanceRequest represents balance update request
|
||||
@@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
Concurrency: req.Concurrency,
|
||||
Status: req.Status,
|
||||
AllowedGroups: req.AllowedGroups,
|
||||
GroupRates: req.GroupRates,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
// GetUserGroupRates 获取当前用户的专属分组倍率配置
|
||||
// GET /api/v1/groups/rates
|
||||
func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, rates)
|
||||
}
|
||||
|
||||
@@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
GroupRates: u.GroupRates,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ type AdminUser struct {
|
||||
User
|
||||
|
||||
Notes string `json:"notes"`
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64 `json:"group_rates,omitempty"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
|
||||
@@ -33,6 +33,7 @@ type GatewayHandler struct {
|
||||
billingCacheService *service.BillingCacheService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
@@ -48,6 +49,7 @@ func NewGatewayHandler(
|
||||
billingCacheService *service.BillingCacheService,
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -70,6 +72,7 @@ func NewGatewayHandler(
|
||||
billingCacheService: billingCacheService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
@@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
@@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
@@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
retryWithFallback := false
|
||||
|
||||
for {
|
||||
@@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
@@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil {
|
||||
// 确定响应状态码
|
||||
respCode := statusCode
|
||||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||||
respCode = *rule.ResponseCode
|
||||
}
|
||||
|
||||
// 确定响应消息
|
||||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
@@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
@@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
handleGeminiFailoverExhausted(c, lastFailoverStatus)
|
||||
lastFailoverErr = failoverErr
|
||||
h.handleGeminiFailoverExhausted(c, lastFailoverErr)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
lastFailoverErr = failoverErr
|
||||
switchCount++
|
||||
log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
@@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error
|
||||
return "", "", &pathParseError{"invalid model action path"}
|
||||
}
|
||||
|
||||
func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) {
|
||||
func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) {
|
||||
if failoverErr == nil {
|
||||
googleError(c, http.StatusBadGateway, "Upstream request failed")
|
||||
return
|
||||
}
|
||||
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil {
|
||||
// 确定响应状态码
|
||||
respCode := statusCode
|
||||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||||
respCode = *rule.ResponseCode
|
||||
}
|
||||
|
||||
// 确定响应消息
|
||||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
googleError(c, respCode, msg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, message := mapGeminiUpstreamError(statusCode)
|
||||
googleError(c, status, message)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type AdminHandlers struct {
|
||||
Subscription *admin.SubscriptionHandler
|
||||
Usage *admin.UsageHandler
|
||||
UserAttribute *admin.UserAttributeHandler
|
||||
ErrorPassthrough *admin.ErrorPassthroughHandler
|
||||
}
|
||||
|
||||
// Handlers contains all HTTP handlers
|
||||
|
||||
@@ -22,11 +22,12 @@ import (
|
||||
|
||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||
type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
apiKeyService *service.APIKeyService
|
||||
errorPassthroughService *service.ErrorPassthroughService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
errorPassthroughService *service.ErrorPassthroughService,
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
@@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler(
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
apiKeyService: apiKeyService,
|
||||
errorPassthroughService: errorPassthroughService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
@@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
return
|
||||
}
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleFailoverExhaustedSimple(c, 502, streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
@@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted)
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
lastFailoverStatus = failoverErr.StatusCode
|
||||
switchCount++
|
||||
log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches)
|
||||
continue
|
||||
@@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
||||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
|
||||
// 确定响应状态码
|
||||
respCode := statusCode
|
||||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||||
respCode = *rule.ResponseCode
|
||||
}
|
||||
|
||||
// 确定响应消息
|
||||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||||
msg = *rule.CustomMessage
|
||||
}
|
||||
|
||||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 使用默认的错误映射
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func ProvideAdminHandlers(
|
||||
subscriptionHandler *admin.SubscriptionHandler,
|
||||
usageHandler *admin.UsageHandler,
|
||||
userAttributeHandler *admin.UserAttributeHandler,
|
||||
errorPassthroughHandler *admin.ErrorPassthroughHandler,
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
@@ -47,6 +48,7 @@ func ProvideAdminHandlers(
|
||||
Subscription: subscriptionHandler,
|
||||
Usage: usageHandler,
|
||||
UserAttribute: userAttributeHandler,
|
||||
ErrorPassthrough: errorPassthroughHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewSubscriptionHandler,
|
||||
admin.NewUsageHandler,
|
||||
admin.NewUserAttributeHandler,
|
||||
admin.NewErrorPassthroughHandler,
|
||||
|
||||
// AdminHandlers and Handlers constructors
|
||||
ProvideAdminHandlers,
|
||||
|
||||
74
backend/internal/model/error_passthrough_rule.go
Normal file
74
backend/internal/model/error_passthrough_rule.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Package model 定义服务层使用的数据模型。
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ErrorPassthroughRule 全局错误透传规则
|
||||
// 用于控制上游错误如何返回给客户端
|
||||
type ErrorPassthroughRule struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"` // 规则名称
|
||||
Enabled bool `json:"enabled"` // 是否启用
|
||||
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
|
||||
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系)
|
||||
Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系)
|
||||
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
|
||||
Platforms []string `json:"platforms"` // 适用平台列表
|
||||
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
|
||||
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
||||
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
||||
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
||||
Description *string `json:"description"` // 规则描述
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// MatchModeAny 表示任一条件匹配即可
|
||||
const MatchModeAny = "any"
|
||||
|
||||
// MatchModeAll 表示所有条件都必须匹配
|
||||
const MatchModeAll = "all"
|
||||
|
||||
// 支持的平台常量
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
)
|
||||
|
||||
// AllPlatforms 返回所有支持的平台列表
|
||||
func AllPlatforms() []string {
|
||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
||||
}
|
||||
|
||||
// Validate 验证规则配置的有效性
|
||||
func (r *ErrorPassthroughRule) Validate() error {
|
||||
if r.Name == "" {
|
||||
return &ValidationError{Field: "name", Message: "name is required"}
|
||||
}
|
||||
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
|
||||
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
|
||||
}
|
||||
// 至少需要配置一个匹配条件(错误码或关键词)
|
||||
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
|
||||
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
|
||||
}
|
||||
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
|
||||
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
|
||||
}
|
||||
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
|
||||
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidationError 表示验证错误
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return e.Field + ": " + e.Message
|
||||
}
|
||||
@@ -71,6 +71,12 @@ var DefaultModels = []Model{
|
||||
DisplayName: "Claude Opus 4.5",
|
||||
CreatedAt: "2025-11-01T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Type: "model",
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
CreatedAt: "2026-02-06T00:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Type: "model",
|
||||
|
||||
109
backend/internal/pkg/googleapi/error.go
Normal file
109
backend/internal/pkg/googleapi/error.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Package googleapi provides helpers for Google-style API responses.
|
||||
package googleapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrorResponse represents a Google API error response
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail contains the error details from Google API
|
||||
type ErrorDetail struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
Details []json.RawMessage `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorDetailInfo contains additional error information
|
||||
type ErrorDetailInfo struct {
|
||||
Type string `json:"@type"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorHelp contains help links
|
||||
type ErrorHelp struct {
|
||||
Type string `json:"@type"`
|
||||
Links []HelpLink `json:"links,omitempty"`
|
||||
}
|
||||
|
||||
// HelpLink represents a help link
|
||||
type HelpLink struct {
|
||||
Description string `json:"description"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// ParseError parses a Google API error response and extracts key information
|
||||
func ParseError(body string) (*ErrorResponse, error) {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse error response: %w", err)
|
||||
}
|
||||
return &errResp, nil
|
||||
}
|
||||
|
||||
// ExtractActivationURL extracts the API activation URL from error details
|
||||
func ExtractActivationURL(body string) string {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check error details for activation URL
|
||||
for _, detailRaw := range errResp.Error.Details {
|
||||
// Parse as ErrorDetailInfo
|
||||
var info ErrorDetailInfo
|
||||
if err := json.Unmarshal(detailRaw, &info); err == nil {
|
||||
if info.Metadata != nil {
|
||||
if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" {
|
||||
return activationURL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as ErrorHelp
|
||||
var help ErrorHelp
|
||||
if err := json.Unmarshal(detailRaw, &help); err == nil {
|
||||
for _, link := range help.Links {
|
||||
if strings.Contains(link.Description, "activation") ||
|
||||
strings.Contains(link.Description, "API activation") ||
|
||||
strings.Contains(link.URL, "/apis/api/") {
|
||||
return link.URL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error
|
||||
func IsServiceDisabledError(body string) bool {
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal([]byte(body), &errResp); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason
|
||||
if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, detailRaw := range errResp.Error.Details {
|
||||
var info ErrorDetailInfo
|
||||
if err := json.Unmarshal(detailRaw, &info); err == nil {
|
||||
if info.Reason == "SERVICE_DISABLED" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
143
backend/internal/pkg/googleapi/error_test.go
Normal file
143
backend/internal/pkg/googleapi/error_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package googleapi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractActivationURL(t *testing.T) {
|
||||
// Test case from the user's error message
|
||||
errorBody := `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.",
|
||||
"status": "PERMISSION_DENIED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "SERVICE_DISABLED",
|
||||
"domain": "googleapis.com",
|
||||
"metadata": {
|
||||
"service": "cloudaicompanion.googleapis.com",
|
||||
"activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843",
|
||||
"consumer": "projects/project-6eca5881-ab73-4736-843",
|
||||
"serviceTitle": "Gemini for Google Cloud API",
|
||||
"containerInfo": "project-6eca5881-ab73-4736-843"
|
||||
}
|
||||
},
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.LocalizedMessage",
|
||||
"locale": "en-US",
|
||||
"message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry."
|
||||
},
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.Help",
|
||||
"links": [
|
||||
{
|
||||
"description": "Google developers console API activation",
|
||||
"url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`
|
||||
|
||||
activationURL := ExtractActivationURL(errorBody)
|
||||
expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843"
|
||||
|
||||
if activationURL != expectedURL {
|
||||
t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsServiceDisabledError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "SERVICE_DISABLED error",
|
||||
body: `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"status": "PERMISSION_DENIED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "SERVICE_DISABLED"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Other 403 error",
|
||||
body: `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"status": "PERMISSION_DENIED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "OTHER_REASON"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "404 error",
|
||||
body: `{
|
||||
"error": {
|
||||
"code": 404,
|
||||
"status": "NOT_FOUND"
|
||||
}
|
||||
}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
body: `invalid json`,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsServiceDisabledError(tt.body)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseError(t *testing.T) {
|
||||
errorBody := `{
|
||||
"error": {
|
||||
"code": 403,
|
||||
"message": "API not enabled",
|
||||
"status": "PERMISSION_DENIED"
|
||||
}
|
||||
}`
|
||||
|
||||
errResp, err := ParseError(errorBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse error: %v", err)
|
||||
}
|
||||
|
||||
if errResp.Error.Code != 403 {
|
||||
t.Errorf("Expected code 403, got %d", errResp.Error.Code)
|
||||
}
|
||||
|
||||
if errResp.Error.Status != "PERMISSION_DENIED" {
|
||||
t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status)
|
||||
}
|
||||
|
||||
if errResp.Error.Message != "API not enabled" {
|
||||
t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,8 @@ type Model struct {
|
||||
|
||||
// DefaultModels OpenAI models list
|
||||
var DefaultModels = []Model{
|
||||
{ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"},
|
||||
{ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"},
|
||||
{ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"},
|
||||
{ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"},
|
||||
{ID: "gpt-5.1-codex-max", Object: "model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"},
|
||||
|
||||
@@ -1089,8 +1089,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
|
||||
payload, id,
|
||||
string(payload), id,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
128
backend/internal/repository/error_passthrough_cache.go
Normal file
128
backend/internal/repository/error_passthrough_cache.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPassthroughCacheKey = "error_passthrough_rules"
|
||||
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
|
||||
errorPassthroughCacheTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type errorPassthroughCache struct {
|
||||
rdb *redis.Client
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughCache 创建错误透传规则缓存
|
||||
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
|
||||
return &errorPassthroughCache{
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// Get 从缓存获取规则列表
|
||||
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
|
||||
// 先检查本地缓存
|
||||
c.localMu.RLock()
|
||||
if c.localCache != nil {
|
||||
rules := c.localCache
|
||||
c.localMu.RUnlock()
|
||||
return rules, true
|
||||
}
|
||||
c.localMu.RUnlock()
|
||||
|
||||
// 从 Redis 获取
|
||||
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
|
||||
if err != nil {
|
||||
if err != redis.Nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var rules []*model.ErrorPassthroughRule
|
||||
if err := json.Unmarshal(data, &rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return rules, true
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
|
||||
data, err := json.Marshal(rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate 使缓存失效
|
||||
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
|
||||
// 清除本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 清除 Redis 缓存
|
||||
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
|
||||
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
|
||||
}
|
||||
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
|
||||
go func() {
|
||||
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
|
||||
defer func() { _ = sub.Close() }()
|
||||
|
||||
ch := sub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-ch:
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 调用处理函数
|
||||
handler()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
178
backend/internal/repository/error_passthrough_repo.go
Normal file
178
backend/internal/repository/error_passthrough_repo.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type errorPassthroughRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
// NewErrorPassthroughRepository 创建错误透传规则仓库
|
||||
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
|
||||
return &errorPassthroughRepository{client: client}
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
rules, err := r.client.ErrorPassthroughRule.Query().
|
||||
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
for i, rule := range rules {
|
||||
result[i] = r.toModel(rule)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(rule), nil
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.Create().
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(created), nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
|
||||
// 处理可选字段
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
} else {
|
||||
builder.ClearErrorCodes()
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
} else {
|
||||
builder.ClearKeywords()
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
} else {
|
||||
builder.ClearPlatforms()
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
} else {
|
||||
builder.ClearResponseCode()
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
} else {
|
||||
builder.ClearCustomMessage()
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
} else {
|
||||
builder.ClearDescription()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(updated), nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// toModel 将 Ent 实体转换为服务模型
|
||||
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: int64(e.ID),
|
||||
Name: e.Name,
|
||||
Enabled: e.Enabled,
|
||||
Priority: e.Priority,
|
||||
ErrorCodes: e.ErrorCodes,
|
||||
Keywords: e.Keywords,
|
||||
MatchMode: e.MatchMode,
|
||||
Platforms: e.Platforms,
|
||||
PassthroughCode: e.PassthroughCode,
|
||||
PassthroughBody: e.PassthroughBody,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
}
|
||||
|
||||
if e.ResponseCode != nil {
|
||||
rule.ResponseCode = e.ResponseCode
|
||||
}
|
||||
if e.CustomMessage != nil {
|
||||
rule.CustomMessage = e.CustomMessage
|
||||
}
|
||||
if e.Description != nil {
|
||||
rule.Description = e.Description
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
return rule
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
|
||||
if defaultIdleTimeoutMinutes <= 0 {
|
||||
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
|
||||
}
|
||||
|
||||
// 预加载 Lua 脚本到 Redis,避免 Pipeline 中出现 NOSCRIPT 错误
|
||||
ctx := context.Background()
|
||||
scripts := []*redis.Script{
|
||||
registerSessionScript,
|
||||
refreshSessionScript,
|
||||
getActiveSessionCountScript,
|
||||
isSessionActiveScript,
|
||||
}
|
||||
for _, script := range scripts {
|
||||
if err := script.Load(ctx, rdb).Err(); err != nil {
|
||||
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &sessionLimitCache{
|
||||
rdb: rdb,
|
||||
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
|
||||
|
||||
113
backend/internal/repository/user_group_rate_repo.go
Normal file
113
backend/internal/repository/user_group_rate_repo.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type userGroupRateRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||
return &userGroupRateRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make(map[int64]float64)
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var rate float64
|
||||
if err := rows.Scan(&groupID, &rate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[groupID] = rate
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rate float64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rate, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||
if len(rates) == 0 {
|
||||
// 如果传入空 map,删除该用户的所有专属倍率
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 分离需要删除和需要 upsert 的记录
|
||||
var toDelete []int64
|
||||
toUpsert := make(map[int64]float64)
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
} else {
|
||||
toUpsert[groupID] = *rate
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的记录
|
||||
for _, groupID := range toDelete {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
|
||||
userID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert 记录
|
||||
now := time.Now()
|
||||
for groupID, rate := range toUpsert {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4)
|
||||
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
|
||||
`, userID, groupID, rate, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
@@ -66,6 +66,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
NewUserAttributeValueRepository,
|
||||
NewUserGroupRateRepository,
|
||||
NewErrorPassthroughRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
@@ -86,6 +88,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
NewRefreshTokenCache,
|
||||
NewErrorPassthroughCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
@@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo, nil)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||
@@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
|
||||
@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
nil, // userSubRepo
|
||||
nil, // userGroupRateRepo
|
||||
nil, // cache
|
||||
&config.Config{},
|
||||
)
|
||||
@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
|
||||
@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// 用户属性管理
|
||||
registerUserAttributeRoutes(admin, h)
|
||||
|
||||
// 错误透传规则管理
|
||||
registerErrorPassthroughRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -391,3 +394,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
|
||||
}
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
rules.GET("", h.Admin.ErrorPassthrough.List)
|
||||
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
|
||||
rules.POST("", h.Admin.ErrorPassthrough.Create)
|
||||
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
|
||||
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ func RegisterUserRoutes(
|
||||
groups := authenticated.Group("/groups")
|
||||
{
|
||||
groups.GET("/available", h.APIKey.GetAvailableGroups)
|
||||
groups.GET("/rates", h.APIKey.GetUserGroupRates)
|
||||
}
|
||||
|
||||
// 使用记录
|
||||
|
||||
@@ -94,6 +94,9 @@ type UpdateUserInput struct {
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@@ -296,6 +299,7 @@ type adminServiceImpl struct {
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
@@ -310,6 +314,7 @@ func NewAdminService(
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
@@ -322,6 +327,7 @@ func NewAdminService(
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
@@ -336,11 +342,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
// 批量加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil && len(users) > 0 {
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
|
||||
} else {
|
||||
user.GroupRates = rates
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
@@ -409,6 +439,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 同步用户专属分组倍率
|
||||
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
||||
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
|
||||
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||
@@ -944,6 +982,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
|
||||
|
||||
// 事务成功后,异步失效受影响用户的订阅缓存
|
||||
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
|
||||
|
||||
@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||||
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
|
||||
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -132,16 +133,18 @@ func NewAPIKeyService(
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.initAuthCache(cfg)
|
||||
return svc
|
||||
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetUserGroupRates 获取用户的专属分组倍率配置
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user group rates: %w", err)
|
||||
}
|
||||
return rates, nil
|
||||
}
|
||||
|
||||
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
|
||||
// Returns nil if valid, error if invalid
|
||||
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
|
||||
|
||||
@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
groupID := int64(9)
|
||||
cacheEntry := &APIKeyAuthCacheEntry{
|
||||
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||
L1TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
require.NotNil(t, svc.authCacheL1)
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||
require.Len(t, cache.deleteAuthKeys, 1)
|
||||
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||
Singleflight: true,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
start := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
300
backend/internal/service/error_passthrough_service.go
Normal file
300
backend/internal/service/error_passthrough_service.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
|
||||
type ErrorPassthroughRepository interface {
|
||||
// List 获取所有规则
|
||||
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
|
||||
// GetByID 根据 ID 获取规则
|
||||
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
|
||||
// Create 创建规则
|
||||
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
|
||||
// Update 更新规则
|
||||
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
|
||||
// Delete 删除规则
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// ErrorPassthroughCache 定义错误透传规则的缓存接口
|
||||
type ErrorPassthroughCache interface {
|
||||
// Get 从缓存获取规则列表
|
||||
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
|
||||
// Set 设置缓存
|
||||
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
|
||||
// Invalidate 使缓存失效
|
||||
Invalidate(ctx context.Context) error
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
NotifyUpdate(ctx context.Context) error
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
SubscribeUpdates(ctx context.Context, handler func())
|
||||
}
|
||||
|
||||
// ErrorPassthroughService 错误透传规则服务
|
||||
type ErrorPassthroughService struct {
|
||||
repo ErrorPassthroughRepository
|
||||
cache ErrorPassthroughCache
|
||||
|
||||
// 本地内存缓存,用于快速匹配
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughService 创建错误透传规则服务
|
||||
func NewErrorPassthroughService(
|
||||
repo ErrorPassthroughRepository,
|
||||
cache ErrorPassthroughCache,
|
||||
) *ErrorPassthroughService {
|
||||
svc := &ErrorPassthroughService{
|
||||
repo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 启动时加载规则到本地缓存
|
||||
ctx := context.Background()
|
||||
if err := svc.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
|
||||
}
|
||||
|
||||
// 订阅缓存更新通知
|
||||
if cache != nil {
|
||||
cache.SubscribeUpdates(ctx, func() {
|
||||
if err := svc.refreshLocalCache(context.Background()); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
return s.repo.List(ctx)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if err := rule.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created, err := s.repo.Create(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if err := rule.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := s.repo.Update(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MatchRule 匹配透传规则
|
||||
// 返回第一个匹配的规则,如果没有匹配则返回 nil
|
||||
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
|
||||
rules := s.getCachedRules()
|
||||
if len(rules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if !s.platformMatches(rule, platform) {
|
||||
continue
|
||||
}
|
||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
||||
s.localCacheMu.RLock()
|
||||
rules := s.localCache
|
||||
s.localCacheMu.RUnlock()
|
||||
|
||||
if rules != nil {
|
||||
return rules
|
||||
}
|
||||
|
||||
// 如果本地缓存为空,尝试刷新
|
||||
ctx := context.Background()
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.localCacheMu.RLock()
|
||||
defer s.localCacheMu.RUnlock()
|
||||
return s.localCache
|
||||
}
|
||||
|
||||
// refreshLocalCache 刷新本地缓存
|
||||
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
|
||||
// 先尝试从 Redis 缓存获取
|
||||
if s.cache != nil {
|
||||
if rules, ok := s.cache.Get(ctx); ok {
|
||||
s.setLocalCache(rules)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 从数据库加载(repo.List 已按 priority 排序)
|
||||
rules, err := s.repo.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新 Redis 缓存
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Set(ctx, rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新本地缓存(setLocalCache 内部会确保排序)
|
||||
s.setLocalCache(rules)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setLocalCache 设置本地缓存
|
||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||
// 按优先级排序
|
||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = sorted
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// invalidateAndNotify 使缓存失效并通知其他实例
|
||||
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
// 刷新本地缓存
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
|
||||
}
|
||||
|
||||
// 通知其他实例
|
||||
if s.cache != nil {
|
||||
if err := s.cache.NotifyUpdate(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// platformMatches 检查平台是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
||||
// 如果没有配置平台限制,则匹配所有平台
|
||||
if len(rule.Platforms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
platform = strings.ToLower(platform)
|
||||
for _, p := range rule.Platforms {
|
||||
if strings.ToLower(p) == platform {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches 检查规则是否匹配
|
||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
||||
hasKeywords := len(rule.Keywords) > 0
|
||||
|
||||
// 如果没有配置任何条件,不匹配
|
||||
if !hasErrorCodes && !hasKeywords {
|
||||
return false
|
||||
}
|
||||
|
||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
||||
|
||||
if rule.MatchMode == model.MatchModeAll {
|
||||
// "all" 模式:所有配置的条件都必须满足
|
||||
return codeMatch && keywordMatch
|
||||
}
|
||||
|
||||
// "any" 模式:任一条件满足即可
|
||||
if hasErrorCodes && hasKeywords {
|
||||
return codeMatch || keywordMatch
|
||||
}
|
||||
return codeMatch && keywordMatch
|
||||
}
|
||||
|
||||
// containsInt 检查切片是否包含指定整数
|
||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
||||
for _, v := range slice {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
755
backend/internal/service/error_passthrough_service_test.go
Normal file
755
backend/internal/service/error_passthrough_service_test.go
Normal file
@@ -0,0 +1,755 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockErrorPassthroughRepo 用于测试的 mock repository
|
||||
type mockErrorPassthroughRepo struct {
|
||||
rules []*model.ErrorPassthroughRule
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
return m.rules, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
for _, r := range m.rules {
|
||||
if r.ID == id {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
rule.ID = int64(len(m.rules) + 1)
|
||||
m.rules = append(m.rules, rule)
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
for i, r := range m.rules {
|
||||
if r.ID == rule.ID {
|
||||
m.rules[i] = rule
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
|
||||
for i, r := range m.rules {
|
||||
if r.ID == id {
|
||||
m.rules = append(m.rules[:i], m.rules[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newTestService 创建测试用的服务实例
|
||||
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
|
||||
repo := &mockErrorPassthroughRepo{rules: rules}
|
||||
svc := &ErrorPassthroughService{
|
||||
repo: repo,
|
||||
cache: nil, // 不使用缓存
|
||||
}
|
||||
// 直接设置本地缓存,避免调用 refreshLocalCache
|
||||
svc.setLocalCache(rules)
|
||||
return svc
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 ruleMatches 核心匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||
// 没有配置任何条件时,不应该匹配
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
||||
"没有配置条件时不应该匹配")
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"状态码匹配 422", 422, "any message", true},
|
||||
{"状态码匹配 400", 400, "any message", true},
|
||||
{"状态码不匹配 500", 500, "any message", false},
|
||||
{"状态码不匹配 429", 429, "any message", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{"context limit", "model not supported"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||
{"关键词不匹配", 422, "some other error", false},
|
||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 MatchRule 的行为:先转换为小写
|
||||
bodyLower := strings.ToLower(tt.body)
|
||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
// any 模式:错误码 OR 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "状态码和关键词都匹配",
|
||||
statusCode: 422,
|
||||
body: "context limit reached",
|
||||
expected: true,
|
||||
reason: "both match",
|
||||
},
|
||||
{
|
||||
name: "只有状态码匹配",
|
||||
statusCode: 422,
|
||||
body: "some other error",
|
||||
expected: true,
|
||||
reason: "code matches, keyword doesn't - OR mode should match",
|
||||
},
|
||||
{
|
||||
name: "只有关键词匹配",
|
||||
statusCode: 500,
|
||||
body: "context limit exceeded",
|
||||
expected: true,
|
||||
reason: "keyword matches, code doesn't - OR mode should match",
|
||||
},
|
||||
{
|
||||
name: "都不匹配",
|
||||
statusCode: 500,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "neither matches",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
// all 模式:错误码 AND 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "状态码和关键词都匹配",
|
||||
statusCode: 422,
|
||||
body: "context limit reached",
|
||||
expected: true,
|
||||
reason: "both match - AND mode should match",
|
||||
},
|
||||
{
|
||||
name: "只有状态码匹配",
|
||||
statusCode: 422,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "code matches but keyword doesn't - AND mode should NOT match",
|
||||
},
|
||||
{
|
||||
name: "只有关键词匹配",
|
||||
statusCode: 500,
|
||||
body: "context limit exceeded",
|
||||
expected: false,
|
||||
reason: "keyword matches but code doesn't - AND mode should NOT match",
|
||||
},
|
||||
{
|
||||
name: "都不匹配",
|
||||
statusCode: 500,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "neither matches",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 platformMatches 平台匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestPlatformMatches(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rulePlatforms []string
|
||||
requestPlatform string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "空平台列表匹配所有",
|
||||
rulePlatforms: []string{},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "nil平台列表匹配所有",
|
||||
rulePlatforms: nil,
|
||||
requestPlatform: "openai",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "精确匹配 anthropic",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "精确匹配 openai",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "openai",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "不匹配 gemini",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "gemini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "大小写不敏感",
|
||||
rulePlatforms: []string{"Anthropic", "OpenAI"},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "匹配 antigravity",
|
||||
rulePlatforms: []string{"antigravity"},
|
||||
requestPlatform: "antigravity",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Platforms: tt.rulePlatforms,
|
||||
}
|
||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 MatchRule 完整匹配流程
|
||||
// =============================================================================
|
||||
|
||||
func TestMatchRule_Priority(t *testing.T) {
|
||||
// 测试规则按优先级排序,优先级小的先匹配
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Low Priority",
|
||||
Enabled: true,
|
||||
Priority: 10,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "High Priority",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
|
||||
assert.Equal(t, "High Priority", matched.Name)
|
||||
}
|
||||
|
||||
func TestMatchRule_DisabledRule(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Disabled Rule",
|
||||
Enabled: false,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Enabled Rule",
|
||||
Enabled: true,
|
||||
Priority: 10,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
|
||||
}
|
||||
|
||||
func TestMatchRule_PlatformFilter(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Anthropic Only",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{"anthropic"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "OpenAI Only",
|
||||
Enabled: true,
|
||||
Priority: 2,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{"openai"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Name: "All Platforms",
|
||||
Enabled: true,
|
||||
Priority: 3,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(1), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("openai", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("gemini", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(3), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("antigravity", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(3), matched.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchRule_NoMatch(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Rule for 422",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 500, []byte("error"))
|
||||
|
||||
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
|
||||
}
|
||||
|
||||
func TestMatchRule_EmptyRules(t *testing.T) {
|
||||
svc := newTestService([]*model.ErrorPassthroughRule{})
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
assert.Nil(t, matched, "没有规则时应返回 nil")
|
||||
}
|
||||
|
||||
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Context Limit",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
Keywords: []string{"Context Limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"完全匹配", "Context Limit reached", true},
|
||||
{"小写匹配", "context limit reached", true},
|
||||
{"大写匹配", "CONTEXT LIMIT REACHED", true},
|
||||
{"混合大小写", "ConTeXt LiMiT error", true},
|
||||
{"不匹配", "some other error", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
|
||||
if tt.expected {
|
||||
assert.NotNil(t, matched)
|
||||
} else {
|
||||
assert.Nil(t, matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试真实场景
|
||||
// =============================================================================
|
||||
|
||||
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
|
||||
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Context Limit Passthrough",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll, // 必须同时满足
|
||||
Platforms: []string{"anthropic", "antigravity"},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
// 测试 Anthropic 平台
|
||||
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
|
||||
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
|
||||
matched := svc.MatchRule("anthropic", 422, body)
|
||||
require.NotNil(t, matched)
|
||||
assert.True(t, matched.PassthroughCode)
|
||||
assert.True(t, matched.PassthroughBody)
|
||||
})
|
||||
|
||||
// 测试 Antigravity 平台
|
||||
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("antigravity", 422, body)
|
||||
require.NotNil(t, matched)
|
||||
})
|
||||
|
||||
// 测试 OpenAI 平台(不在规则的平台列表中)
|
||||
t.Run("OpenAI should not match", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("openai", 422, body)
|
||||
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
|
||||
})
|
||||
|
||||
// 测试状态码不匹配
|
||||
t.Run("Wrong status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("anthropic", 400, body)
|
||||
assert.Nil(t, matched, "状态码不匹配")
|
||||
})
|
||||
|
||||
// 测试关键词不匹配
|
||||
t.Run("Wrong keyword", func(t *testing.T) {
|
||||
body := []byte(`{"error":"rate limit exceeded"}`)
|
||||
matched := svc.MatchRule("anthropic", 422, body)
|
||||
assert.Nil(t, matched, "关键词不匹配")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
|
||||
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
|
||||
customMsg := "Service temporarily unavailable, please try again later"
|
||||
responseCode := 503
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Hide Internal Errors",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{500, 502, 503},
|
||||
MatchMode: model.MatchModeAny,
|
||||
PassthroughCode: false,
|
||||
ResponseCode: &responseCode,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: &customMsg,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.False(t, matched.PassthroughCode)
|
||||
assert.Equal(t, 503, *matched.ResponseCode)
|
||||
assert.False(t, matched.PassthroughBody)
|
||||
assert.Equal(t, customMsg, *matched.CustomMessage)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 model.Validate
|
||||
// =============================================================================
|
||||
|
||||
func TestErrorPassthroughRule_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule *model.ErrorPassthroughRule
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "有效规则 - 透传模式(含错误码)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "有效规则 - 透传模式(含关键词)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAny,
|
||||
Keywords: []string{"context limit"},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "有效规则 - 自定义响应",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAll,
|
||||
ErrorCodes: []int{500},
|
||||
Keywords: []string{"internal error"},
|
||||
PassthroughCode: false,
|
||||
ResponseCode: testIntPtr(503),
|
||||
PassthroughBody: false,
|
||||
CustomMessage: testStrPtr("Custom error"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "缺少名称",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "name",
|
||||
},
|
||||
{
|
||||
name: "无效的匹配模式",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Invalid Mode",
|
||||
MatchMode: "invalid",
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "match_mode",
|
||||
},
|
||||
{
|
||||
name: "缺少匹配条件(错误码和关键词都为空)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "No Conditions",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "conditions",
|
||||
},
|
||||
{
|
||||
name: "缺少匹配条件(nil切片)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Nil Conditions",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: nil,
|
||||
Keywords: nil,
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "conditions",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码但未提供值",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Missing Code",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: false,
|
||||
ResponseCode: nil,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "response_code",
|
||||
},
|
||||
{
|
||||
name: "自定义消息但未提供值",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Missing Message",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: nil,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "custom_message",
|
||||
},
|
||||
{
|
||||
name: "自定义消息为空字符串",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Empty Message",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: testStrPtr(""),
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "custom_message",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.rule.Validate()
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
validationErr, ok := err.(*model.ValidationError)
|
||||
require.True(t, ok, "应该返回 ValidationError")
|
||||
assert.Equal(t, tt.errorField, validationErr.Field)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func testIntPtr(i int) *int { return &i }
|
||||
func testStrPtr(s string) *string { return &s }
|
||||
288
backend/internal/service/gateway_cached_tokens_test.go
Normal file
288
backend/internal/service/gateway_cached_tokens_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ---------- reconcileCachedTokens 单元测试 ----------
|
||||
|
||||
func TestReconcileCachedTokens_NilUsage(t *testing.T) {
|
||||
assert.False(t, reconcileCachedTokens(nil))
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_AlreadyHasCacheRead(t *testing.T) {
|
||||
// 已有标准字段,不应覆盖
|
||||
usage := map[string]any{
|
||||
"cache_read_input_tokens": float64(100),
|
||||
"cached_tokens": float64(50),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(100), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_KimiStyle(t *testing.T) {
|
||||
// Kimi 风格:cache_read_input_tokens=0,cached_tokens>0
|
||||
usage := map[string]any{
|
||||
"input_tokens": float64(23),
|
||||
"cache_creation_input_tokens": float64(0),
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cached_tokens": float64(23),
|
||||
}
|
||||
assert.True(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_NoCachedTokens(t *testing.T) {
|
||||
// 无 cached_tokens 字段(原生 Claude)
|
||||
usage := map[string]any{
|
||||
"input_tokens": float64(100),
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cache_creation_input_tokens": float64(0),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_CachedTokensZero(t *testing.T) {
|
||||
// cached_tokens 为 0,不应覆盖
|
||||
usage := map[string]any{
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cached_tokens": float64(0),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_MissingCacheReadField(t *testing.T) {
|
||||
// cache_read_input_tokens 字段完全不存在,cached_tokens > 0
|
||||
usage := map[string]any{
|
||||
"cached_tokens": float64(42),
|
||||
}
|
||||
assert.True(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(42), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
// ---------- 流式 message_start 事件 reconcile 测试 ----------
|
||||
|
||||
func TestStreamingReconcile_MessageStart(t *testing.T) {
|
||||
// 模拟 Kimi 返回的 message_start SSE 事件
|
||||
eventJSON := `{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "kimi",
|
||||
"usage": {
|
||||
"input_tokens": 23,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 23
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
require.Equal(t, "message_start", eventType)
|
||||
|
||||
// 模拟 processSSEEvent 中的 reconcile 逻辑
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 cache_read_input_tokens 已被填充
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, ok := msg["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
|
||||
|
||||
// 验证重新序列化后 JSON 也包含正确值
|
||||
data, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(23), gjson.GetBytes(data, "message.usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestStreamingReconcile_MessageStart_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 不返回 cached_tokens,reconcile 不应改变任何值
|
||||
eventJSON := `{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"cache_creation_input_tokens": 50,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, ok := msg["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
// ---------- 流式 message_delta 事件 reconcile 测试 ----------
|
||||
|
||||
func TestStreamingReconcile_MessageDelta(t *testing.T) {
|
||||
// 模拟 Kimi 返回的 message_delta SSE 事件
|
||||
eventJSON := `{
|
||||
"type": "message_delta",
|
||||
"usage": {
|
||||
"output_tokens": 7,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 15
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
require.Equal(t, "message_delta", eventType)
|
||||
|
||||
// 模拟 processSSEEvent 中的 reconcile 逻辑
|
||||
usage, ok := event["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
reconcileCachedTokens(usage)
|
||||
assert.Equal(t, float64(15), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamingReconcile_MessageDelta_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 的 message_delta 通常没有 cached_tokens
|
||||
eventJSON := `{
|
||||
"type": "message_delta",
|
||||
"usage": {
|
||||
"output_tokens": 50
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
usage, ok := event["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
reconcileCachedTokens(usage)
|
||||
_, hasCacheRead := usage["cache_read_input_tokens"]
|
||||
assert.False(t, hasCacheRead, "不应为原生 Claude 响应注入 cache_read_input_tokens")
|
||||
}
|
||||
|
||||
// ---------- 非流式响应 reconcile 测试 ----------
|
||||
|
||||
func TestNonStreamingReconcile_KimiResponse(t *testing.T) {
|
||||
// 模拟 Kimi 非流式响应
|
||||
body := []byte(`{
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
"model": "kimi",
|
||||
"usage": {
|
||||
"input_tokens": 23,
|
||||
"output_tokens": 7,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 23,
|
||||
"prompt_tokens": 23,
|
||||
"completion_tokens": 7
|
||||
}
|
||||
}`)
|
||||
|
||||
// 模拟 handleNonStreamingResponse 中的逻辑
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
// reconcile
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证内部 usage(计费用)
|
||||
assert.Equal(t, 23, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 23, response.Usage.InputTokens)
|
||||
assert.Equal(t, 7, response.Usage.OutputTokens)
|
||||
|
||||
// 验证返回给客户端的 JSON body
|
||||
assert.Equal(t, int64(23), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestNonStreamingReconcile_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 响应:cache_read_input_tokens 已有值
|
||||
body := []byte(`{
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 20,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}`)
|
||||
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
// CacheReadInputTokens == 30,条件不成立,整个 reconcile 分支不会执行
|
||||
assert.NotZero(t, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 30, response.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestNonStreamingReconcile_NoCachedTokens(t *testing.T) {
|
||||
// 没有 cached_tokens 字段
|
||||
body := []byte(`{
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0
|
||||
}
|
||||
}`)
|
||||
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cache_read_input_tokens 应保持为 0
|
||||
assert.Equal(t, 0, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, int64(0), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
@@ -370,7 +370,8 @@ type ForwardResult struct {
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
@@ -384,6 +385,7 @@ type GatewayService struct {
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
@@ -405,6 +407,7 @@ func NewGatewayService(
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
@@ -424,6 +427,7 @@ func NewGatewayService(
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
@@ -3281,7 +3285,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -3311,10 +3315,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
@@ -3358,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -3755,6 +3757,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||
return extractUpstreamErrorMessage(body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorMessage(body []byte) string {
|
||||
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||
@@ -3822,7 +3830,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
|
||||
@@ -4168,6 +4176,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
eventName = eventType
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
}
|
||||
if eventType == "message_delta" {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
if needModelReplace {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
||||
@@ -4518,6 +4540,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
@@ -4609,10 +4642,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -4773,10 +4813,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -5289,3 +5336,21 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// reconcileCachedTokens 兼容 Kimi 等上游:
|
||||
// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens
|
||||
func reconcileCachedTokens(usage map[string]any) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
cacheRead, _ := usage["cache_read_input_tokens"].(float64)
|
||||
if cacheRead > 0 {
|
||||
return false // 已有标准字段,无需处理
|
||||
}
|
||||
cached, _ := usage["cached_tokens"].(float64)
|
||||
if cached <= 0 {
|
||||
return false
|
||||
}
|
||||
usage["cache_read_input_tokens"] = cached
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
|
||||
}
|
||||
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
|
||||
@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||
}
|
||||
|
||||
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
|
||||
// 当 LoadCodeAssist 返回了 currentTier / paidTier(表示账号已注册)但没有返回 cloudaicompanionProject 时:
|
||||
// - 不要再调用 onboardUser(通常不会再分配 project_id,且可能触发 INVALID_ARGUMENT)
|
||||
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
|
||||
if loadResp != nil {
|
||||
registeredTierID := strings.TrimSpace(loadResp.GetTier())
|
||||
if registeredTierID != "" {
|
||||
// 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。
|
||||
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
|
||||
|
||||
// Try to get project from Cloud Resource Manager
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
|
||||
// No project found - user must provide project_id manually
|
||||
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
|
||||
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
|
||||
}
|
||||
}
|
||||
|
||||
// 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser
|
||||
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
TierID: tierID,
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
|
||||
@@ -21,6 +21,17 @@ const (
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-none": "gpt-5.3",
|
||||
"gpt-5.3-low": "gpt-5.3",
|
||||
"gpt-5.3-medium": "gpt-5.3",
|
||||
"gpt-5.3-high": "gpt-5.3",
|
||||
"gpt-5.3-xhigh": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-low": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-medium": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||
@@ -156,6 +167,12 @@ func normalizeCodexModel(model string) string {
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
|
||||
return "gpt-5.3"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||
return "gpt-5.1-codex-max"
|
||||
}
|
||||
|
||||
@@ -176,6 +176,19 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
require.Len(t, input, 0)
|
||||
}
|
||||
|
||||
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt 5.3 codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
require.Equal(t, expected, normalizeCodexModel(input))
|
||||
}
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -1131,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
|
||||
@@ -579,6 +579,7 @@ func (s *PricingService) extractBaseName(model string) string {
|
||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// Claude模型系列匹配规则
|
||||
familyPatterns := map[string][]string{
|
||||
"opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"},
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
@@ -651,7 +652,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// 回退顺序:
|
||||
// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等)
|
||||
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
// 3. gpt-5.3-codex -> gpt-5.2-codex
|
||||
// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 尝试的回退变体
|
||||
variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
|
||||
@@ -663,6 +665,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gpt-5.3-codex") {
|
||||
if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
|
||||
@@ -21,6 +21,10 @@ type User struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
|
||||
TotpEnabled bool // 是否启用 TOTP
|
||||
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
|
||||
|
||||
// CanBindGroup checks whether a user can bind to a given group.
|
||||
// For standard groups:
|
||||
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
|
||||
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
|
||||
// - Public groups (non-exclusive): all users can bind
|
||||
// - Exclusive groups: only users with the group in AllowedGroups can bind
|
||||
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
// 公开分组(非专属):所有用户都可以绑定
|
||||
if !isExclusive {
|
||||
return true
|
||||
}
|
||||
return !isExclusive
|
||||
// 专属分组:需要在 AllowedGroups 中
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *User) SetPassword(password string) error {
|
||||
|
||||
25
backend/internal/service/user_group_rate.go
Normal file
25
backend/internal/service/user_group_rate.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
type UserGroupRateRepository interface {
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// 如果未设置专属倍率,返回 nil
|
||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
|
||||
DeleteByGroupID(ctx context.Context, groupID int64) error
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
|
||||
DeleteByUserID(ctx context.Context, userID int64) error
|
||||
}
|
||||
@@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserAttributeService,
|
||||
NewUsageCache,
|
||||
NewTotpService,
|
||||
NewErrorPassthroughService,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user