fix: 预扣额度使用 relay info 传递

This commit is contained in:
Xyfacai
2025-09-11 16:04:32 +08:00
parent db6a788e0d
commit b25ac0bfb6
2 changed files with 14 additions and 14 deletions

View File

@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil { if newAPIError != nil {
return return
} }
defer func() { defer func() {
// Only return quota if downstream failed and quota was actually pre-consumed // Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil && preConsumedQuota != 0 { if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) service.ReturnPreConsumedQuota(c, relayInfo)
} }
}() }()

View File

@@ -13,13 +13,13 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
if preConsumedQuota != 0 { if relayInfo.FinalPreConsumedQuota != 0 {
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
gopool.Go(func() { gopool.Go(func() {
relayInfoCopy := *relayInfo relayInfoCopy := *relayInfo
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false)
if err != nil { if err != nil {
common.SysLog("error return pre-consumed quota: " + err.Error()) common.SysLog("error return pre-consumed quota: " + err.Error())
} }
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
// PreConsumeQuota checks if the user has enough quota to pre-consume. // PreConsumeQuota checks if the user has enough quota to pre-consume.
// It returns the pre-consumed quota if successful, or an error if not. // It returns the pre-consumed quota if successful, or an error if not.
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false) userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil { if err != nil {
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
} }
if userQuota <= 0 { if userQuota <= 0 {
return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
trustQuota := common.GetTrustQuota() trustQuota := common.GetTrustQuota()
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil { if err != nil {
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
} }
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
} }
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
} }
relayInfo.FinalPreConsumedQuota = preConsumedQuota relayInfo.FinalPreConsumedQuota = preConsumedQuota
return preConsumedQuota, nil return nil
} }