From a9f739a7e2dc062ea4361ded6efb739f876ba3ec Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 1 Dec 2024 08:24:41 +0800 Subject: [PATCH] refactor: improve validation logic and error handling in relay-text.go - Simplified validation checks for MaxTokens and Messages fields. - Enhanced error messages for better clarity. - Updated goroutine to avoid passing context unnecessarily. --- relay/relay-text.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/relay/relay-text.go b/relay/relay-text.go index b050ac51..3c6ed39e 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -2,11 +2,9 @@ package relay import ( "bytes" - "context" "encoding/json" "errors" "fmt" - "github.com/bytedance/sonic" "io" "math" "net/http" @@ -20,6 +18,8 @@ import ( "strings" "time" + "github.com/bytedance/sonic" + "github.com/gin-gonic/gin" ) @@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) textRequest.Model = c.Param("model") } - if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { + if textRequest.MaxTokens > math.MaxInt32/2 { return nil, errors.New("max_tokens is invalid") } if textRequest.Model == "" { @@ -48,12 +48,12 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) return nil, errors.New("field prompt is required") } case relayconstant.RelayModeChatCompletions: - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { + if len(textRequest.Messages) == 0 { return nil, errors.New("field messages is required") } case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeModerations: - if textRequest.Input == "" || textRequest.Input == nil { + if textRequest.Input == nil || textRequest.Input == "" { return nil, errors.New("field input is required") } case relayconstant.RelayModeEdits: @@ -264,7 +264,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } if userQuota-preConsumedQuota < 0 { - return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest) + return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest) } err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { @@ -298,13 +298,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false) + go func() { + relayInfoCopy := *relayInfo + + err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false) if err != nil { common.SysError("error return pre-consumed quota: " + err.Error()) } - }(c) + }() } }