diff --git a/controller/relay.go b/controller/relay.go index 07c3aeaa..23d72515 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) - preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if newAPIError != nil { return } defer func() { // Only return quota if downstream failed and quota was actually pre-consumed - if newAPIError != nil && preConsumedQuota != 0 { - service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 { + service.ReturnPreConsumedQuota(c, relayInfo) } }() diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 86b04e52..3cfabc1a 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -13,13 +13,13 @@ import ( "github.com/gin-gonic/gin" ) -func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { - if preConsumedQuota != 0 { - logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { + if relayInfo.FinalPreConsumedQuota != 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota))) gopool.Go(func() { relayInfoCopy := *relayInfo - err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false) if err != nil { common.SysLog("error return pre-consumed quota: " + err.Error()) } @@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr // PreConsumeQuota checks if the user has enough quota to pre-consume. // It returns the pre-consumed quota if successful, or an error if not. -func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { +func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota <= 0 { - return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } if userQuota-preConsumedQuota < 0 { - return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } trustQuota := common.GetTrustQuota() @@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) } relayInfo.FinalPreConsumedQuota = preConsumedQuota - return preConsumedQuota, nil + return nil }