refactor: improve validation logic and error handling in relay-text.go

- Simplified validation checks for MaxTokens and Messages fields.
- Enhanced error messages for better clarity.
- Updated goroutine to avoid passing context unnecessarily.
This commit is contained in:
CalciumIon
2024-12-01 08:24:41 +08:00
parent 6d4edc1f5b
commit a9f739a7e2

View File

@@ -2,11 +2,9 @@ package relay
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/bytedance/sonic"
"io" "io"
"math" "math"
"net/http" "net/http"
@@ -20,6 +18,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/bytedance/sonic"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
textRequest.Model = c.Param("model") textRequest.Model = c.Param("model")
} }
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { if textRequest.MaxTokens > math.MaxInt32/2 {
return nil, errors.New("max_tokens is invalid") return nil, errors.New("max_tokens is invalid")
} }
if textRequest.Model == "" { if textRequest.Model == "" {
@@ -48,12 +48,12 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
return nil, errors.New("field prompt is required") return nil, errors.New("field prompt is required")
} }
case relayconstant.RelayModeChatCompletions: case relayconstant.RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 { if len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required") return nil, errors.New("field messages is required")
} }
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
case relayconstant.RelayModeModerations: case relayconstant.RelayModeModerations:
if textRequest.Input == "" || textRequest.Input == nil { if textRequest.Input == nil || textRequest.Input == "" {
return nil, errors.New("field input is required") return nil, errors.New("field input is required")
} }
case relayconstant.RelayModeEdits: case relayconstant.RelayModeEdits:
@@ -264,7 +264,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
if userQuota-preConsumedQuota < 0 { if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest) return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
} }
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil { if err != nil {
@@ -298,13 +298,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) { func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
if preConsumedQuota != 0 { if preConsumedQuota != 0 {
go func(ctx context.Context) { go func() {
// return pre-consumed quota relayInfoCopy := *relayInfo
err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false)
err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
if err != nil { if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error()) common.SysError("error return pre-consumed quota: " + err.Error())
} }
}(c) }()
} }
} }