Merge remote-tracking branch 'origin/alpha' into alpha
This commit is contained in:
@@ -445,7 +445,7 @@ func testAllChannels(notify bool) error {
|
|||||||
|
|
||||||
// disable channel
|
// disable channel
|
||||||
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
||||||
go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// enable channel
|
// enable channel
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -61,8 +62,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
|
|||||||
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||||
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
group := c.GetString("group")
|
group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
newAPIError *types.NewAPIError
|
newAPIError *types.NewAPIError
|
||||||
@@ -172,35 +173,9 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|||||||
|
|
||||||
if newAPIError == nil {
|
if newAPIError == nil {
|
||||||
return
|
return
|
||||||
} else {
|
|
||||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
|
|
||||||
// 保存错误日志到mysql中
|
|
||||||
userId := c.GetInt("id")
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
modelName := c.GetString("original_model")
|
|
||||||
tokenId := c.GetInt("token_id")
|
|
||||||
userGroup := c.GetString("group")
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
other := make(map[string]interface{})
|
|
||||||
other["error_type"] = newAPIError.GetErrorType()
|
|
||||||
other["error_code"] = newAPIError.GetErrorCode()
|
|
||||||
other["status_code"] = newAPIError.StatusCode
|
|
||||||
other["channel_id"] = channelId
|
|
||||||
other["channel_name"] = c.GetString("channel_name")
|
|
||||||
other["channel_type"] = c.GetInt("channel_type")
|
|
||||||
adminInfo := make(map[string]interface{})
|
|
||||||
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
|
||||||
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
|
||||||
if isMultiKey {
|
|
||||||
adminInfo["is_multi_key"] = true
|
|
||||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
|
||||||
}
|
|
||||||
other["admin_info"] = adminInfo
|
|
||||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
||||||
|
|
||||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||||
break
|
break
|
||||||
@@ -298,12 +273,42 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
|||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
|
||||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
|
||||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
|
||||||
service.DisableChannel(channelError, err.Error())
|
gopool.Go(func() {
|
||||||
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
|
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||||
|
service.DisableChannel(channelError, err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||||
|
// 保存错误日志到mysql中
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
modelName := c.GetString("original_model")
|
||||||
|
tokenId := c.GetInt("token_id")
|
||||||
|
userGroup := c.GetString("group")
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
other := make(map[string]interface{})
|
||||||
|
other["error_type"] = err.GetErrorType()
|
||||||
|
other["error_code"] = err.GetErrorCode()
|
||||||
|
other["status_code"] = err.StatusCode
|
||||||
|
other["channel_id"] = channelId
|
||||||
|
other["channel_name"] = c.GetString("channel_name")
|
||||||
|
other["channel_type"] = c.GetInt("channel_type")
|
||||||
|
adminInfo := make(map[string]interface{})
|
||||||
|
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
|
||||||
|
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
|
||||||
|
if isMultiKey {
|
||||||
|
adminInfo["is_multi_key"] = true
|
||||||
|
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||||
|
}
|
||||||
|
other["admin_info"] = adminInfo
|
||||||
|
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayMidjourney(c *gin.Context) {
|
func RelayMidjourney(c *gin.Context) {
|
||||||
|
|||||||
@@ -9,9 +9,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ConditionOperation struct {
|
type ConditionOperation struct {
|
||||||
Path string `json:"path"` // JSON路径
|
Path string `json:"path"` // JSON路径
|
||||||
Mode string `json:"mode"` // full, prefix, suffix, contains
|
Mode string `json:"mode"` // full, prefix, suffix, contains, gt, gte, lt, lte
|
||||||
Value string `json:"value"` // 匹配的值
|
Value interface{} `json:"value"` // 匹配的值
|
||||||
|
Invert bool `json:"invert"` // 反选功能,true表示取反结果
|
||||||
|
PassMissingKey bool `json:"pass_missing_key"` // 未获取到json key时的行为
|
||||||
}
|
}
|
||||||
|
|
||||||
type ParamOperation struct {
|
type ParamOperation struct {
|
||||||
@@ -34,11 +36,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
|
|||||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||||
// 使用新方法
|
// 使用新方法
|
||||||
result, err := applyOperations(string(jsonData), operations)
|
result, err := applyOperations(string(jsonData), operations)
|
||||||
if err != nil {
|
return []byte(result), err
|
||||||
// 新方法失败,回退到旧方法
|
|
||||||
return applyOperationsLegacy(jsonData, paramOverride)
|
|
||||||
}
|
|
||||||
return []byte(result), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 直接使用旧方法
|
// 直接使用旧方法
|
||||||
@@ -95,9 +93,15 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
|||||||
if mode, ok := condMap["mode"].(string); ok {
|
if mode, ok := condMap["mode"].(string); ok {
|
||||||
condition.Mode = mode
|
condition.Mode = mode
|
||||||
}
|
}
|
||||||
if value, ok := condMap["value"].(string); ok {
|
if value, ok := condMap["value"]; ok {
|
||||||
condition.Value = value
|
condition.Value = value
|
||||||
}
|
}
|
||||||
|
if invert, ok := condMap["invert"].(bool); ok {
|
||||||
|
condition.Invert = invert
|
||||||
|
}
|
||||||
|
if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
|
||||||
|
condition.PassMissingKey = passMissingKey
|
||||||
|
}
|
||||||
operation.Conditions = append(operation.Conditions, condition)
|
operation.Conditions = append(operation.Conditions, condition)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,52 +120,131 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) bool {
|
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||||
if len(conditions) == 0 {
|
if len(conditions) == 0 {
|
||||||
return true // 没有条件,直接通过
|
return true, nil // 没有条件,直接通过
|
||||||
}
|
}
|
||||||
results := make([]bool, len(conditions))
|
results := make([]bool, len(conditions))
|
||||||
|
|
||||||
for i, condition := range conditions {
|
for i, condition := range conditions {
|
||||||
results[i] = checkSingleCondition(jsonStr, condition)
|
result, err := checkSingleCondition(jsonStr, condition)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
results[i] = result
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToUpper(logic) == "AND" {
|
if strings.ToUpper(logic) == "AND" {
|
||||||
for _, result := range results {
|
for _, result := range results {
|
||||||
if !result {
|
if !result {
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true, nil
|
||||||
} else {
|
} else {
|
||||||
for _, result := range results {
|
for _, result := range results {
|
||||||
if result {
|
if result {
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) bool {
|
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||||
value := gjson.Get(jsonStr, condition.Path)
|
value := gjson.Get(jsonStr, condition.Path)
|
||||||
if !value.Exists() {
|
if !value.Exists() {
|
||||||
return false
|
if condition.PassMissingKey {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
valueStr := value.String()
|
// 利用gjson的类型解析
|
||||||
targetStr := condition.Value
|
targetBytes, err := json.Marshal(condition.Value)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||||
|
}
|
||||||
|
targetValue := gjson.ParseBytes(targetBytes)
|
||||||
|
|
||||||
switch strings.ToLower(condition.Mode) {
|
result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if condition.Invert {
|
||||||
|
result = !result
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
|
||||||
|
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
|
||||||
|
switch mode {
|
||||||
case "full":
|
case "full":
|
||||||
return valueStr == targetStr
|
return compareEqual(jsonValue, targetValue)
|
||||||
case "prefix":
|
case "prefix":
|
||||||
return strings.HasPrefix(valueStr, targetStr)
|
return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
|
||||||
case "suffix":
|
case "suffix":
|
||||||
return strings.HasSuffix(valueStr, targetStr)
|
return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
|
||||||
case "contains":
|
case "contains":
|
||||||
return strings.Contains(valueStr, targetStr)
|
return strings.Contains(jsonValue.String(), targetValue.String()), nil
|
||||||
|
case "gt":
|
||||||
|
return compareNumeric(jsonValue, targetValue, "gt")
|
||||||
|
case "gte":
|
||||||
|
return compareNumeric(jsonValue, targetValue, "gte")
|
||||||
|
case "lt":
|
||||||
|
return compareNumeric(jsonValue, targetValue, "lt")
|
||||||
|
case "lte":
|
||||||
|
return compareNumeric(jsonValue, targetValue, "lte")
|
||||||
default:
|
default:
|
||||||
return valueStr == targetStr // 默认精准匹配
|
return false, fmt.Errorf("unsupported comparison mode: %s", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
|
||||||
|
// 对布尔值特殊处理
|
||||||
|
if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
|
||||||
|
(targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
|
||||||
|
return jsonValue.Bool() == targetValue.Bool(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果类型不同,报错
|
||||||
|
if jsonValue.Type != targetValue.Type {
|
||||||
|
return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch jsonValue.Type {
|
||||||
|
case gjson.True, gjson.False:
|
||||||
|
return jsonValue.Bool() == targetValue.Bool(), nil
|
||||||
|
case gjson.Number:
|
||||||
|
return jsonValue.Num == targetValue.Num, nil
|
||||||
|
case gjson.String:
|
||||||
|
return jsonValue.String() == targetValue.String(), nil
|
||||||
|
default:
|
||||||
|
return jsonValue.String() == targetValue.String(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
|
||||||
|
// 只有数字类型才支持数值比较
|
||||||
|
if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
|
||||||
|
return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonNum := jsonValue.Num
|
||||||
|
targetNum := targetValue.Num
|
||||||
|
|
||||||
|
switch operator {
|
||||||
|
case "gt":
|
||||||
|
return jsonNum > targetNum, nil
|
||||||
|
case "gte":
|
||||||
|
return jsonNum >= targetNum, nil
|
||||||
|
case "lt":
|
||||||
|
return jsonNum < targetNum, nil
|
||||||
|
case "lte":
|
||||||
|
return jsonNum <= targetNum, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("unsupported numeric operator: %s", operator)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,11 +267,14 @@ func applyOperations(jsonStr string, operations []ParamOperation) (string, error
|
|||||||
result := jsonStr
|
result := jsonStr
|
||||||
for _, op := range operations {
|
for _, op := range operations {
|
||||||
// 检查条件是否满足
|
// 检查条件是否满足
|
||||||
if !checkConditions(result, op.Conditions, op.Logic) {
|
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
continue // 条件不满足,跳过当前操作
|
continue // 条件不满足,跳过当前操作
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
switch op.Mode {
|
switch op.Mode {
|
||||||
case "delete":
|
case "delete":
|
||||||
result, err = sjson.Delete(result, op.Path)
|
result, err = sjson.Delete(result, op.Path)
|
||||||
|
|||||||
@@ -84,6 +84,8 @@ func Path2RelayMode(path string) int {
|
|||||||
relayMode = RelayModeRealtime
|
relayMode = RelayModeRealtime
|
||||||
} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
|
} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
|
||||||
relayMode = RelayModeGemini
|
relayMode = RelayModeGemini
|
||||||
|
} else if strings.HasPrefix(path, "/mj") {
|
||||||
|
relayMode = Path2RelayModeMidjourney(path)
|
||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user