refactor: improve validation logic and error handling in relay-text.go
- Simplified validation checks for MaxTokens and Messages fields. - Enhanced error messages for better clarity. - Updated goroutine to avoid passing context unnecessarily.
This commit is contained in:
@@ -2,11 +2,9 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/bytedance/sonic"
|
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -20,6 +18,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bytedance/sonic"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
|||||||
textRequest.Model = c.Param("model")
|
textRequest.Model = c.Param("model")
|
||||||
}
|
}
|
||||||
|
|
||||||
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
|
if textRequest.MaxTokens > math.MaxInt32/2 {
|
||||||
return nil, errors.New("max_tokens is invalid")
|
return nil, errors.New("max_tokens is invalid")
|
||||||
}
|
}
|
||||||
if textRequest.Model == "" {
|
if textRequest.Model == "" {
|
||||||
@@ -48,12 +48,12 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo)
|
|||||||
return nil, errors.New("field prompt is required")
|
return nil, errors.New("field prompt is required")
|
||||||
}
|
}
|
||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
if len(textRequest.Messages) == 0 {
|
||||||
return nil, errors.New("field messages is required")
|
return nil, errors.New("field messages is required")
|
||||||
}
|
}
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
case relayconstant.RelayModeModerations:
|
case relayconstant.RelayModeModerations:
|
||||||
if textRequest.Input == "" || textRequest.Input == nil {
|
if textRequest.Input == nil || textRequest.Input == "" {
|
||||||
return nil, errors.New("field input is required")
|
return nil, errors.New("field input is required")
|
||||||
}
|
}
|
||||||
case relayconstant.RelayModeEdits:
|
case relayconstant.RelayModeEdits:
|
||||||
@@ -264,7 +264,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
if userQuota-preConsumedQuota < 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
|
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -298,13 +298,14 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
|
|
||||||
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
|
func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, userQuota int, preConsumedQuota int) {
|
||||||
if preConsumedQuota != 0 {
|
if preConsumedQuota != 0 {
|
||||||
go func(ctx context.Context) {
|
go func() {
|
||||||
// return pre-consumed quota
|
relayInfoCopy := *relayInfo
|
||||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, -preConsumedQuota, 0, false)
|
|
||||||
|
err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}(c)
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user