diff --git a/.env.example b/.env.example index bece06db..ea246427 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,8 @@ # 调试相关配置 # 启用pprof # ENABLE_PPROF=true +# 启用调试模式 +# DEBUG=true # 数据库相关配置 # 数据库连接字符串 @@ -41,6 +43,14 @@ # 更新任务启用 # UPDATE_TASK=true +# 对话超时设置 +# 所有请求超时时间,单位秒,默认为0,表示不限制 +# RELAY_TIMEOUT=0 +# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值 +# STREAMING_TIMEOUT=120 + +# Gemini 识别图片 最大图片数量 +# GEMINI_VISION_MAX_IMAGE_NUM=16 # 会话密钥 # SESSION_SECRET=random_string @@ -58,8 +68,6 @@ # GET_MEDIA_TOKEN_NOT_STREAM=true # 设置 Dify 渠道是否输出工作流和节点信息到客户端 # DIFY_DEBUG=true -# 设置流式一次回复的超时时间 -# STREAMING_TIMEOUT=90 # 节点类型 diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md new file mode 100644 index 00000000..4f6e41ac --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -0,0 +1,19 @@ +### PR 类型 + +- [ ] Bug 修复 +- [ ] 新功能 +- [ ] 文档更新 +- [ ] 其他 + +### PR 是否包含破坏性更新? + +- [ ] 是 +- [ ] 否 + +### PR 描述 + +**请在下方详细描述您的 PR,包括目的、实现细节等。** + +### **重要提示** + +**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。** diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 3210065b..1bc786ac 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -26,6 +26,7 @@ jobs: - name: Build Frontend env: CI: "" + NODE_OPTIONS: "--max-old-space-size=4096" run: | cd web bun install diff --git a/.github/workflows/pr-target-branch-check.yml b/.github/workflows/pr-target-branch-check.yml new file mode 100644 index 00000000..e7bd4c81 --- /dev/null +++ b/.github/workflows/pr-target-branch-check.yml @@ -0,0 +1,21 @@ +name: Check PR Branching Strategy +on: + pull_request: + types: [opened, synchronize, reopened, edited] + +jobs: + check-branching-strategy: + runs-on: ubuntu-latest + steps: + - name: Enforce branching strategy + run: | + if [[ "${{ github.base_ref }}" == "main" ]]; then + if [[ "${{ github.head_ref }}" != "alpha" ]]; then + echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch." + exit 1 + fi + elif [[ "${{ github.base_ref }}" != "alpha" ]]; then + echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch." + exit 1 + fi + echo "Branching strategy check passed." \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 214ceaa3..3b42089b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one- FROM alpine -RUN apk update \ - && apk upgrade \ +RUN apk upgrade --no-cache \ && apk add --no-cache ca-certificates tzdata ffmpeg \ && update-ca-certificates diff --git a/README.en.md b/README.en.md index 10a3cdb0..b4ae921a 100644 --- a/README.en.md +++ b/README.en.md @@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables): - `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false` -- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds +- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds - `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true` - `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true` - `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true` diff --git a/README.md b/README.md index 6ba3574c..05423548 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,6 @@ GoReportCard - - CodeRabbit Pull Request Reviews -

@@ -103,7 +100,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do 详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables): - `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false` -- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒 +- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒 - `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true` - `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true` - `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true` diff --git a/common/api_type.go b/common/api_type.go new file mode 100644 index 00000000..d9071236 --- /dev/null +++ b/common/api_type.go @@ -0,0 +1,71 @@ +package common + +import "one-api/constant" + +func ChannelType2APIType(channelType int) (int, bool) { + apiType := -1 + switch channelType { + case constant.ChannelTypeOpenAI: + apiType = constant.APITypeOpenAI + case constant.ChannelTypeAnthropic: + apiType = constant.APITypeAnthropic + case constant.ChannelTypeBaidu: + apiType = constant.APITypeBaidu + case constant.ChannelTypePaLM: + apiType = constant.APITypePaLM + case constant.ChannelTypeZhipu: + apiType = constant.APITypeZhipu + case constant.ChannelTypeAli: + apiType = constant.APITypeAli + case constant.ChannelTypeXunfei: + apiType = constant.APITypeXunfei + case constant.ChannelTypeAIProxyLibrary: + apiType = constant.APITypeAIProxyLibrary + case constant.ChannelTypeTencent: + apiType = constant.APITypeTencent + case constant.ChannelTypeGemini: + apiType = constant.APITypeGemini + case constant.ChannelTypeZhipu_v4: + apiType = constant.APITypeZhipuV4 + case constant.ChannelTypeOllama: + apiType = constant.APITypeOllama + case constant.ChannelTypePerplexity: + apiType = constant.APITypePerplexity + case constant.ChannelTypeAws: + apiType = constant.APITypeAws + case constant.ChannelTypeCohere: + apiType = constant.APITypeCohere + case constant.ChannelTypeDify: + apiType = constant.APITypeDify + case constant.ChannelTypeJina: + apiType = constant.APITypeJina + case constant.ChannelCloudflare: + apiType = constant.APITypeCloudflare + case constant.ChannelTypeSiliconFlow: + apiType = constant.APITypeSiliconFlow + case constant.ChannelTypeVertexAi: + apiType = constant.APITypeVertexAi + case constant.ChannelTypeMistral: + apiType = constant.APITypeMistral + case constant.ChannelTypeDeepSeek: + apiType = constant.APITypeDeepSeek + case constant.ChannelTypeMokaAI: + apiType = constant.APITypeMokaAI + case constant.ChannelTypeVolcEngine: + apiType = constant.APITypeVolcEngine + case constant.ChannelTypeBaiduV2: + apiType = constant.APITypeBaiduV2 + case constant.ChannelTypeOpenRouter: + apiType = constant.APITypeOpenRouter + case constant.ChannelTypeXinference: + apiType = constant.APITypeXinference + case constant.ChannelTypeXai: + apiType = constant.APITypeXai + case constant.ChannelTypeCoze: + apiType = constant.APITypeCoze + } + if apiType == -1 { + return constant.APITypeOpenAI, false + } + return apiType, true +} diff --git a/common/constants.go b/common/constants.go index bee00506..e4f5f047 100644 --- a/common/constants.go +++ b/common/constants.go @@ -193,107 +193,3 @@ const ( ChannelStatusManuallyDisabled = 2 // also don't use 0 ChannelStatusAutoDisabled = 3 ) - -const ( - ChannelTypeUnknown = 0 - ChannelTypeOpenAI = 1 - ChannelTypeMidjourney = 2 - ChannelTypeAzure = 3 - ChannelTypeOllama = 4 - ChannelTypeMidjourneyPlus = 5 - ChannelTypeOpenAIMax = 6 - ChannelTypeOhMyGPT = 7 - ChannelTypeCustom = 8 - ChannelTypeAILS = 9 - ChannelTypeAIProxy = 10 - ChannelTypePaLM = 11 - ChannelTypeAPI2GPT = 12 - ChannelTypeAIGC2D = 13 - ChannelTypeAnthropic = 14 - ChannelTypeBaidu = 15 - ChannelTypeZhipu = 16 - ChannelTypeAli = 17 - ChannelTypeXunfei = 18 - ChannelType360 = 19 - ChannelTypeOpenRouter = 20 - ChannelTypeAIProxyLibrary = 21 - ChannelTypeFastGPT = 22 - ChannelTypeTencent = 23 - ChannelTypeGemini = 24 - ChannelTypeMoonshot = 25 - ChannelTypeZhipu_v4 = 26 - ChannelTypePerplexity = 27 - ChannelTypeLingYiWanWu = 31 - ChannelTypeAws = 33 - ChannelTypeCohere = 34 - ChannelTypeMiniMax = 35 - ChannelTypeSunoAPI = 36 - ChannelTypeDify = 37 - ChannelTypeJina = 38 - ChannelCloudflare = 39 - ChannelTypeSiliconFlow = 40 - ChannelTypeVertexAi = 41 - ChannelTypeMistral = 42 - ChannelTypeDeepSeek = 43 - ChannelTypeMokaAI = 44 - ChannelTypeVolcEngine = 45 - ChannelTypeBaiduV2 = 46 - ChannelTypeXinference = 47 - ChannelTypeXai = 48 - ChannelTypeCoze = 49 - ChannelTypeDummy // this one is only for count, do not add any channel after this - -) - -var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "http://localhost:11434", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://api.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.tencentcloudapi.com", //23 - "https://generativelanguage.googleapis.com", //24 - "https://api.moonshot.cn", //25 - "https://open.bigmodel.cn", //26 - "https://api.perplexity.ai", //27 - "", //28 - "", //29 - "", //30 - "https://api.lingyiwanwu.com", //31 - "", //32 - "", //33 - "https://api.cohere.ai", //34 - "https://api.minimax.chat", //35 - "", //36 - "https://api.dify.ai", //37 - "https://api.jina.ai", //38 - "https://api.cloudflare.com", //39 - "https://api.siliconflow.cn", //40 - "", //41 - "https://api.mistral.ai", //42 - "https://api.deepseek.com", //43 - "https://api.moka.ai", //44 - "https://ark.cn-beijing.volces.com", //45 - "https://qianfan.baidubce.com", //46 - "", //47 - "https://api.x.ai", //48 - "https://api.coze.cn", //49 -} diff --git a/common/endpoint_type.go b/common/endpoint_type.go new file mode 100644 index 00000000..a0ca73ea --- /dev/null +++ b/common/endpoint_type.go @@ -0,0 +1,41 @@ +package common + +import "one-api/constant" + +// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点) +func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType { + var endpointTypes []constant.EndpointType + switch channelType { + case constant.ChannelTypeJina: + endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank} + //case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus: + // endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney} + //case constant.ChannelTypeSunoAPI: + // endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno} + //case constant.ChannelTypeKling: + // endpointTypes = []constant.EndpointType{constant.EndpointTypeKling} + //case constant.ChannelTypeJimeng: + // endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng} + case constant.ChannelTypeAws: + fallthrough + case constant.ChannelTypeAnthropic: + endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI} + case constant.ChannelTypeVertexAi: + fallthrough + case constant.ChannelTypeGemini: + endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI} + case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点 + endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI} + default: + if IsOpenAIResponseOnlyModel(modelName) { + endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse} + } else { + endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI} + } + } + if IsImageGenerationModel(modelName) { + // add to first + endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...) + } + return endpointTypes +} diff --git a/common/gin.go b/common/gin.go index 4a909dfc..62c4c692 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,10 +2,11 @@ package common import ( "bytes" - "encoding/json" "github.com/gin-gonic/gin" "io" + "one-api/constant" "strings" + "time" ) const KeyRequestBody = "key_request_body" @@ -31,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { } contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) + err = UnmarshalJson(requestBody, &v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this @@ -43,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } + +func SetContextKey(c *gin.Context, key constant.ContextKey, value any) { + c.Set(string(key), value) +} + +func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) { + return c.Get(string(key)) +} + +func GetContextKeyString(c *gin.Context, key constant.ContextKey) string { + return c.GetString(string(key)) +} + +func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int { + return c.GetInt(string(key)) +} + +func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool { + return c.GetBool(string(key)) +} + +func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string { + return c.GetStringSlice(string(key)) +} + +func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any { + return c.GetStringMap(string(key)) +} + +func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time { + return c.GetTime(string(key)) +} diff --git a/common/http.go b/common/http.go new file mode 100644 index 00000000..d2e824ef --- /dev/null +++ b/common/http.go @@ -0,0 +1,57 @@ +package common + +import ( + "bytes" + "fmt" + "io" + "net/http" + + "github.com/gin-gonic/gin" +) + +func CloseResponseBodyGracefully(httpResponse *http.Response) { + if httpResponse == nil || httpResponse.Body == nil { + return + } + err := httpResponse.Body.Close() + if err != nil { + SysError("failed to close response body: " + err.Error()) + } +} + +func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) { + if c.Writer == nil { + return + } + + body := io.NopCloser(bytes.NewBuffer(data)) + + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + if src != nil { + for k, v := range src.Header { + // avoid setting Content-Length + if k == "Content-Length" { + continue + } + c.Writer.Header().Set(k, v[0]) + } + } + + // set Content-Length header manually BEFORE calling WriteHeader + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + + // Write header with status code (this sends the headers) + if src != nil { + c.Writer.WriteHeader(src.StatusCode) + } else { + c.Writer.WriteHeader(http.StatusOK) + } + + _, err := io.Copy(c.Writer, body) + if err != nil { + LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error())) + } +} diff --git a/common/init.go b/common/init.go index c0caf0a1..d70a09dd 100644 --- a/common/init.go +++ b/common/init.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "one-api/constant" "os" "path/filepath" "strconv" @@ -24,7 +25,7 @@ func printHelp() { fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") } -func LoadEnv() { +func InitEnv() { flag.Parse() if *PrintVersion { @@ -95,4 +96,25 @@ func LoadEnv() { GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true) GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180)) + + initConstantEnv() +} + +func initConstantEnv() { + constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120) + constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true) + constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) + // ForceStreamOption 覆盖请求参数,强制返回usage信息 + constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) + constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) + constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) + constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true) + constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") + constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) + constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) + constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) + // GenerateDefaultToken 是否生成初始令牌,默认关闭。 + constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) + // 是否启用错误日志 + constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false) } diff --git a/common/json.go b/common/json.go index cec8f16b..512ad0c3 100644 --- a/common/json.go +++ b/common/json.go @@ -5,12 +5,16 @@ import ( "encoding/json" ) -func DecodeJson(data []byte, v any) error { - return json.NewDecoder(bytes.NewReader(data)).Decode(v) +func UnmarshalJson(data []byte, v any) error { + return json.Unmarshal(data, v) } -func DecodeJsonStr(data string, v any) error { - return DecodeJson(StringToByteSlice(data), v) +func UnmarshalJsonStr(data string, v any) error { + return json.Unmarshal(StringToByteSlice(data), v) +} + +func DecodeJson(reader *bytes.Reader, v any) error { + return json.NewDecoder(reader).Decode(v) } func EncodeJson(v any) ([]byte, error) { diff --git a/common/model.go b/common/model.go new file mode 100644 index 00000000..14ca1911 --- /dev/null +++ b/common/model.go @@ -0,0 +1,42 @@ +package common + +import "strings" + +var ( + // OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses. + OpenAIResponseOnlyModels = []string{ + "o3-pro", + "o3-deep-research", + "o4-mini-deep-research", + } + ImageGenerationModels = []string{ + "dall-e-3", + "dall-e-2", + "gpt-image-1", + "prefix:imagen-", + "flux-", + "flux.1-", + } +) + +func IsOpenAIResponseOnlyModel(modelName string) bool { + for _, m := range OpenAIResponseOnlyModels { + if strings.Contains(modelName, m) { + return true + } + } + return false +} + +func IsImageGenerationModel(modelName string) bool { + modelName = strings.ToLower(modelName) + for _, m := range ImageGenerationModels { + if strings.Contains(modelName, m) { + return true + } + if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) { + return true + } + } + return false +} diff --git a/common/page_info.go b/common/page_info.go new file mode 100644 index 00000000..20a90fa2 --- /dev/null +++ b/common/page_info.go @@ -0,0 +1,62 @@ +package common + +import ( + "github.com/gin-gonic/gin" + "strconv" +) + +type PageInfo struct { + Page int `json:"page"` // page num 页码 + PageSize int `json:"page_size"` // page size 页大小 + StartTimestamp int64 `json:"start_timestamp"` // 秒级 + EndTimestamp int64 `json:"end_timestamp"` // 秒级 + + Total int `json:"total"` // 总条数,后设置 + Items any `json:"items"` // 数据,后设置 +} + +func (p *PageInfo) GetStartIdx() int { + return (p.Page - 1) * p.PageSize +} + +func (p *PageInfo) GetEndIdx() int { + return p.Page * p.PageSize +} + +func (p *PageInfo) GetPageSize() int { + return p.PageSize +} + +func (p *PageInfo) GetPage() int { + return p.Page +} + +func (p *PageInfo) SetTotal(total int) { + p.Total = total +} + +func (p *PageInfo) SetItems(items any) { + p.Items = items +} + +func GetPageQuery(c *gin.Context) (*PageInfo, error) { + pageInfo := &PageInfo{} + err := c.BindQuery(pageInfo) + if err != nil { + return nil, err + } + if pageInfo.Page < 1 { + // 兼容 + page, _ := strconv.Atoi(c.Query("p")) + if page != 0 { + pageInfo.Page = page + } else { + pageInfo.Page = 1 + } + } + + if pageInfo.PageSize == 0 { + pageInfo.PageSize = ItemsPerPage + } + return pageInfo, nil +} diff --git a/common/redis.go b/common/redis.go index ba35331a..c7287837 100644 --- a/common/redis.go +++ b/common/redis.go @@ -16,6 +16,10 @@ import ( var RDB *redis.Client var RedisEnabled = true +func RedisKeyCacheSeconds() int { + return SyncFrequency +} + // InitRedisClient This function is called after init() func InitRedisClient() (err error) { if os.Getenv("REDIS_CONN_STRING") == "" { @@ -141,7 +145,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error { txn := RDB.TxPipeline() txn.HSet(ctx, key, data) - txn.Expire(ctx, key, expiration) + + // 只有在 expiration 大于 0 时才设置过期时间 + if expiration > 0 { + txn.Expire(ctx, key, expiration) + } _, err := txn.Exec(ctx) if err != nil { diff --git a/common/utils.go b/common/utils.go index 587de537..17aecd95 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "math/big" "math/rand" "net" + "net/url" "os" "os/exec" "runtime" @@ -249,13 +250,55 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) { } // GetAudioDuration returns the duration of an audio file in seconds. -func GetAudioDuration(ctx context.Context, filename string) (float64, error) { +func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) { // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) output, err := c.Output() if err != nil { return 0, errors.Wrap(err, "failed to get audio duration") } + durationStr := string(bytes.TrimSpace(output)) + if durationStr == "N/A" { + // Create a temporary output file name + tmpFp, err := os.CreateTemp("", "audio-*"+ext) + if err != nil { + return 0, errors.Wrap(err, "failed to create temporary file") + } + tmpName := tmpFp.Name() + // Close immediately so ffmpeg can open the file on Windows. + _ = tmpFp.Close() + defer os.Remove(tmpName) - return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) + // ffmpeg -y -i filename -vcodec copy -acodec copy + ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName) + if err := ffmpegCmd.Run(); err != nil { + return 0, errors.Wrap(err, "failed to run ffmpeg") + } + + // Recalculate the duration of the new file + c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg") + } + durationStr = string(bytes.TrimSpace(output)) + } + return strconv.ParseFloat(durationStr, 64) +} + +// BuildURL concatenates base and endpoint, returns the complete url string +func BuildURL(base string, endpoint string) string { + u, err := url.Parse(base) + if err != nil { + return base + endpoint + } + end := endpoint + if end == "" { + end = "/" + } + ref, err := url.Parse(end) + if err != nil { + return base + endpoint + } + return u.ResolveReference(ref).String() } diff --git a/constant/README.md b/constant/README.md new file mode 100644 index 00000000..12a9ffad --- /dev/null +++ b/constant/README.md @@ -0,0 +1,26 @@ +# constant 包 (`/constant`) + +该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。 + +## 当前文件 + +| 文件 | 说明 | +|----------------------|---------------------------------------------------------------------| +| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 | +| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 | +| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 | +| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 | +| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 | +| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 | +| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 | +| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 | +| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 | +| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 | + +## 使用约定 + +1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。 +2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。 +3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。 + +> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。 \ No newline at end of file diff --git a/constant/api_type.go b/constant/api_type.go new file mode 100644 index 00000000..ae867870 --- /dev/null +++ b/constant/api_type.go @@ -0,0 +1,34 @@ +package constant + +const ( + APITypeOpenAI = iota + APITypeAnthropic + APITypePaLM + APITypeBaidu + APITypeZhipu + APITypeAli + APITypeXunfei + APITypeAIProxyLibrary + APITypeTencent + APITypeGemini + APITypeZhipuV4 + APITypeOllama + APITypePerplexity + APITypeAws + APITypeCohere + APITypeDify + APITypeJina + APITypeCloudflare + APITypeSiliconFlow + APITypeVertexAi + APITypeMistral + APITypeDeepSeek + APITypeMokaAI + APITypeVolcEngine + APITypeBaiduV2 + APITypeOpenRouter + APITypeXinference + APITypeXai + APITypeCoze + APITypeDummy // this one is only for count, do not add any channel after this +) diff --git a/constant/cache_key.go b/constant/cache_key.go index 27cb3b75..0601396a 100644 --- a/constant/cache_key.go +++ b/constant/cache_key.go @@ -1,14 +1,5 @@ package constant -import "one-api/common" - -var ( - TokenCacheSeconds = common.SyncFrequency - UserId2GroupCacheSeconds = common.SyncFrequency - UserId2QuotaCacheSeconds = common.SyncFrequency - UserId2StatusCacheSeconds = common.SyncFrequency -) - // Cache keys const ( UserGroupKeyFmt = "user_group:%d" diff --git a/constant/channel.go b/constant/channel.go new file mode 100644 index 00000000..224121e7 --- /dev/null +++ b/constant/channel.go @@ -0,0 +1,109 @@ +package constant + +const ( + ChannelTypeUnknown = 0 + ChannelTypeOpenAI = 1 + ChannelTypeMidjourney = 2 + ChannelTypeAzure = 3 + ChannelTypeOllama = 4 + ChannelTypeMidjourneyPlus = 5 + ChannelTypeOpenAIMax = 6 + ChannelTypeOhMyGPT = 7 + ChannelTypeCustom = 8 + ChannelTypeAILS = 9 + ChannelTypeAIProxy = 10 + ChannelTypePaLM = 11 + ChannelTypeAPI2GPT = 12 + ChannelTypeAIGC2D = 13 + ChannelTypeAnthropic = 14 + ChannelTypeBaidu = 15 + ChannelTypeZhipu = 16 + ChannelTypeAli = 17 + ChannelTypeXunfei = 18 + ChannelType360 = 19 + ChannelTypeOpenRouter = 20 + ChannelTypeAIProxyLibrary = 21 + ChannelTypeFastGPT = 22 + ChannelTypeTencent = 23 + ChannelTypeGemini = 24 + ChannelTypeMoonshot = 25 + ChannelTypeZhipu_v4 = 26 + ChannelTypePerplexity = 27 + ChannelTypeLingYiWanWu = 31 + ChannelTypeAws = 33 + ChannelTypeCohere = 34 + ChannelTypeMiniMax = 35 + ChannelTypeSunoAPI = 36 + ChannelTypeDify = 37 + ChannelTypeJina = 38 + ChannelCloudflare = 39 + ChannelTypeSiliconFlow = 40 + ChannelTypeVertexAi = 41 + ChannelTypeMistral = 42 + ChannelTypeDeepSeek = 43 + ChannelTypeMokaAI = 44 + ChannelTypeVolcEngine = 45 + ChannelTypeBaiduV2 = 46 + ChannelTypeXinference = 47 + ChannelTypeXai = 48 + ChannelTypeCoze = 49 + ChannelTypeKling = 50 + ChannelTypeJimeng = 51 + ChannelTypeDummy // this one is only for count, do not add any channel after this + +) + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "http://localhost:11434", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://api.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.tencentcloudapi.com", //23 + "https://generativelanguage.googleapis.com", //24 + "https://api.moonshot.cn", //25 + "https://open.bigmodel.cn", //26 + "https://api.perplexity.ai", //27 + "", //28 + "", //29 + "", //30 + "https://api.lingyiwanwu.com", //31 + "", //32 + "", //33 + "https://api.cohere.ai", //34 + "https://api.minimax.chat", //35 + "", //36 + "https://api.dify.ai", //37 + "https://api.jina.ai", //38 + "https://api.cloudflare.com", //39 + "https://api.siliconflow.cn", //40 + "", //41 + "https://api.mistral.ai", //42 + "https://api.deepseek.com", //43 + "https://api.moka.ai", //44 + "https://ark.cn-beijing.volces.com", //45 + "https://qianfan.baidubce.com", //46 + "", //47 + "https://api.x.ai", //48 + "https://api.coze.cn", //49 + "https://api.klingai.com", //50 + "https://visual.volcengineapi.com", //51 +} diff --git a/constant/context_key.go b/constant/context_key.go index 4b4d5cae..71e02f01 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -1,10 +1,35 @@ package constant +type ContextKey string + const ( - ContextKeyRequestStartTime = "request_start_time" - ContextKeyUserSetting = "user_setting" - ContextKeyUserQuota = "user_quota" - ContextKeyUserStatus = "user_status" - ContextKeyUserEmail = "user_email" - ContextKeyUserGroup = "user_group" + ContextKeyOriginalModel ContextKey = "original_model" + ContextKeyRequestStartTime ContextKey = "request_start_time" + + /* token related keys */ + ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota" + ContextKeyTokenKey ContextKey = "token_key" + ContextKeyTokenId ContextKey = "token_id" + ContextKeyTokenGroup ContextKey = "token_group" + ContextKeyTokenAllowIps ContextKey = "allow_ips" + ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id" + ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled" + ContextKeyTokenModelLimit ContextKey = "token_model_limit" + + /* channel related keys */ + ContextKeyBaseUrl ContextKey = "base_url" + ContextKeyChannelType ContextKey = "channel_type" + ContextKeyChannelId ContextKey = "channel_id" + ContextKeyChannelSetting ContextKey = "channel_setting" + ContextKeyParamOverride ContextKey = "param_override" + + /* user related keys */ + ContextKeyUserId ContextKey = "id" + ContextKeyUserSetting ContextKey = "user_setting" + ContextKeyUserQuota ContextKey = "user_quota" + ContextKeyUserStatus ContextKey = "user_status" + ContextKeyUserEmail ContextKey = "user_email" + ContextKeyUserGroup ContextKey = "user_group" + ContextKeyUsingGroup ContextKey = "group" + ContextKeyUserName ContextKey = "username" ) diff --git a/constant/endpoint_type.go b/constant/endpoint_type.go new file mode 100644 index 00000000..ef096b75 --- /dev/null +++ b/constant/endpoint_type.go @@ -0,0 +1,16 @@ +package constant + +type EndpointType string + +const ( + EndpointTypeOpenAI EndpointType = "openai" + EndpointTypeOpenAIResponse EndpointType = "openai-response" + EndpointTypeAnthropic EndpointType = "anthropic" + EndpointTypeGemini EndpointType = "gemini" + EndpointTypeJinaRerank EndpointType = "jina-rerank" + EndpointTypeImageGeneration EndpointType = "image-generation" + //EndpointTypeMidjourney EndpointType = "midjourney-proxy" + //EndpointTypeSuno EndpointType = "suno-proxy" + //EndpointTypeKling EndpointType = "kling" + //EndpointTypeJimeng EndpointType = "jimeng" +) diff --git a/constant/env.go b/constant/env.go index 612f3e8b..8bc2f131 100644 --- a/constant/env.go +++ b/constant/env.go @@ -1,9 +1,5 @@ package constant -import ( - "one-api/common" -) - var StreamingTimeout int var DifyDebug bool var MaxFileDownloadMB int @@ -17,39 +13,3 @@ var NotifyLimitCount int var NotificationLimitDurationMinute int var GenerateDefaultToken bool var ErrorLogEnabled bool - -//var GeminiModelMap = map[string]string{ -// "gemini-1.0-pro": "v1", -//} - -func InitEnv() { - StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60) - DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true) - MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20) - // ForceStreamOption 覆盖请求参数,强制返回usage信息 - ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) - GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) - GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) - UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) - AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview") - GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16) - NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2) - NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10) - // GenerateDefaultToken 是否生成初始令牌,默认关闭。 - GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false) - // 是否启用错误日志 - ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false) - - //modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) - //if modelVersionMapStr == "" { - // return - //} - //for _, pair := range strings.Split(modelVersionMapStr, ",") { - // parts := strings.Split(pair, ":") - // if len(parts) == 2 { - // GeminiModelMap[parts[0]] = parts[1] - // } else { - // common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) - // } - //} -} diff --git a/constant/midjourney.go b/constant/midjourney.go index 1bf4d549..5934be2f 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -22,6 +22,8 @@ const ( MjActionPan = "PAN" MjActionSwapFace = "SWAP_FACE" MjActionUpload = "UPLOAD" + MjActionVideo = "VIDEO" + MjActionEdits = "EDITS" ) var MidjourneyModel2Action = map[string]string{ @@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{ "mj_pan": MjActionPan, "swap_face": MjActionSwapFace, "mj_upload": MjActionUpload, + "mj_video": MjActionVideo, + "mj_edits": MjActionEdits, } diff --git a/constant/task.go b/constant/task.go index 1a68b812..e7af39a6 100644 --- a/constant/task.go +++ b/constant/task.go @@ -5,11 +5,16 @@ type TaskPlatform string const ( TaskPlatformSuno TaskPlatform = "suno" TaskPlatformMidjourney = "mj" + TaskPlatformKling TaskPlatform = "kling" + TaskPlatformJimeng TaskPlatform = "jimeng" ) const ( SunoActionMusic = "MUSIC" SunoActionLyrics = "LYRICS" + + TaskActionGenerate = "generate" + TaskActionTextGenerate = "textGenerate" ) var SunoModel2Action = map[string]string{ diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 2bda0fd2..3c92c78b 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,11 +4,14 @@ import ( "encoding/json" "errors" "fmt" + "github.com/shopspring/decimal" "io" "net/http" "one-api/common" + "one-api/constant" "one-api/model" "one-api/service" + "one-api/setting" "strconv" "time" @@ -304,34 +307,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) { return balance, nil } +func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) { + url := "https://api.moonshot.cn/v1/users/me/balance" + body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + if err != nil { + return 0, err + } + + type MoonshotBalanceData struct { + AvailableBalance float64 `json:"available_balance"` + VoucherBalance float64 `json:"voucher_balance"` + CashBalance float64 `json:"cash_balance"` + } + + type MoonshotBalanceResponse struct { + Code int `json:"code"` + Data MoonshotBalanceData `json:"data"` + Scode string `json:"scode"` + Status bool `json:"status"` + } + + response := MoonshotBalanceResponse{} + err = json.Unmarshal(body, &response) + if err != nil { + return 0, err + } + if !response.Status || response.Code != 0 { + return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode) + } + availableBalanceCny := response.Data.AvailableBalance + availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64() + channel.UpdateBalance(availableBalanceUsd) + return availableBalanceUsd, nil +} + func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { channel.BaseURL = &baseURL } switch channel.Type { - case common.ChannelTypeOpenAI: + case constant.ChannelTypeOpenAI: if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: + case constant.ChannelTypeCustom: baseURL = channel.GetBaseURL() //case common.ChannelTypeOpenAISB: // return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: + case constant.ChannelTypeAIProxy: return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: + case constant.ChannelTypeAPI2GPT: return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: + case constant.ChannelTypeAIGC2D: return updateChannelAIGC2DBalance(channel) - case common.ChannelTypeSiliconFlow: + case constant.ChannelTypeSiliconFlow: return updateChannelSiliconFlowBalance(channel) - case common.ChannelTypeDeepSeek: + case constant.ChannelTypeDeepSeek: return updateChannelDeepSeekBalance(channel) - case common.ChannelTypeOpenRouter: + case constant.ChannelTypeOpenRouter: return updateChannelOpenRouterBalance(channel) + case constant.ChannelTypeMoonshot: + return updateChannelMoonshotBalance(channel) default: return 0, errors.New("尚未实现") } diff --git a/controller/channel-test.go b/controller/channel-test.go index 52f8a7ef..0b474c25 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -11,12 +11,12 @@ import ( "net/http/httptest" "net/url" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/middleware" "one-api/model" "one-api/relay" relaycommon "one-api/relay/common" - "one-api/relay/constant" "one-api/relay/helper" "one-api/service" "strconv" @@ -31,15 +31,21 @@ import ( func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { tik := time.Now() - if channel.Type == common.ChannelTypeMidjourney { + if channel.Type == constant.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } - if channel.Type == common.ChannelTypeMidjourneyPlus { - return errors.New("midjourney plus channel test is not supported!!!"), nil + if channel.Type == constant.ChannelTypeMidjourneyPlus { + return errors.New("midjourney plus channel test is not supported"), nil } - if channel.Type == common.ChannelTypeSunoAPI { + if channel.Type == constant.ChannelTypeSunoAPI { return errors.New("suno channel test is not supported"), nil } + if channel.Type == constant.ChannelTypeKling { + return errors.New("kling channel test is not supported"), nil + } + if channel.Type == constant.ChannelTypeJimeng { + return errors.New("jimeng channel test is not supported"), nil + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -50,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr strings.HasPrefix(testModel, "m3e") || // m3e 系列模型 strings.Contains(testModel, "bge-") || // bge 系列模型 strings.Contains(testModel, "embed") || - channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型 + channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型 requestPath = "/v1/embeddings" // 修改请求路径 } @@ -90,13 +96,13 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr info := relaycommon.GenRelayInfo(c) - err = helper.ModelMappedHelper(c, info) + err = helper.ModelMappedHelper(c, info, nil) if err != nil { return err, nil } testModel = info.UpstreamModelName - apiType, _ := constant.ChannelType2APIType(channel.Type) + apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil @@ -165,10 +171,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() consumedTime := float64(milliseconds) / 1000.0 - other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, - usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio) + other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio, + usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", - quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) + quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } @@ -196,7 +202,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest { testRequest.MaxTokens = 50 } } else if strings.Contains(model, "gemini") { - testRequest.MaxTokens = 300 + testRequest.MaxTokens = 3000 } else { testRequest.MaxTokens = 10 } @@ -312,7 +318,7 @@ func testAllChannels(notify bool) error { channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } - + if notify { service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } diff --git a/controller/channel.go b/controller/channel.go index dc4e0cbf..59177b7a 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" "one-api/model" "strconv" "strings" @@ -40,6 +41,17 @@ type OpenAIModelsResponse struct { Success bool `json:"success"` } +func parseStatusFilter(statusParam string) int { + switch strings.ToLower(statusParam) { + case "enabled", "1": + return common.ChannelStatusEnabled + case "disabled", "0": + return 0 + default: + return -1 + } +} + func GetAllChannels(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) pageSize, _ := strconv.Atoi(c.Query("page_size")) @@ -52,44 +64,100 @@ func GetAllChannels(c *gin.Context) { channelData := make([]*model.Channel, 0) idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) + statusParam := c.Query("status") + // statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual) + statusFilter := parseStatusFilter(statusParam) + // type filter + typeStr := c.Query("type") + typeFilter := -1 + if typeStr != "" { + if t, err := strconv.Atoi(typeStr); err == nil { + typeFilter = t + } + } var total int64 if enableTagMode { - // tag 分页:先分页 tag,再取各 tag 下 channels tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize) if err != nil { c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) return } for _, tag := range tags { - if tag != nil && *tag != "" { - tagChannel, err := model.GetChannelsByTag(*tag, idSort) - if err == nil { - channelData = append(channelData, tagChannel...) - } + if tag == nil || *tag == "" { + continue } + tagChannels, err := model.GetChannelsByTag(*tag, idSort) + if err != nil { + continue + } + filtered := make([]*model.Channel, 0) + for _, ch := range tagChannels { + if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled { + continue + } + if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled { + continue + } + if typeFilter >= 0 && ch.Type != typeFilter { + continue + } + filtered = append(filtered, ch) + } + channelData = append(channelData, filtered...) } - // 计算 tag 总数用于分页 total, _ = model.CountAllTags() } else { - channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort) + baseQuery := model.DB.Model(&model.Channel{}) + if typeFilter >= 0 { + baseQuery = baseQuery.Where("type = ?", typeFilter) + } + if statusFilter == common.ChannelStatusEnabled { + baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled) + } else if statusFilter == 0 { + baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled) + } + + baseQuery.Count(&total) + + order := "priority desc" + if idSort { + order = "id desc" + } + + err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error if err != nil { c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()}) return } - channelData = channels - total, _ = model.CountAllChannels() + } + + countQuery := model.DB.Model(&model.Channel{}) + if statusFilter == common.ChannelStatusEnabled { + countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled) + } else if statusFilter == 0 { + countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled) + } + var results []struct { + Type int64 + Count int64 + } + _ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error + typeCounts := make(map[int64]int64) + for _, r := range results { + typeCounts[r.Type] = r.Count } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", "data": gin.H{ - "items": channelData, - "total": total, - "page": p, - "page_size": pageSize, + "items": channelData, + "total": total, + "page": p, + "page_size": pageSize, + "type_counts": typeCounts, }, }) return @@ -114,22 +182,15 @@ func FetchUpstreamModels(c *gin.Context) { return } - //if channel.Type != common.ChannelTypeOpenAI { - // c.JSON(http.StatusOK, gin.H{ - // "success": false, - // "message": "仅支持 OpenAI 类型渠道", - // }) - // return - //} - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := constant.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } url := fmt.Sprintf("%s/v1/models", baseURL) switch channel.Type { - case common.ChannelTypeGemini: + case constant.ChannelTypeGemini: url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) - case common.ChannelTypeAli: + case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) } body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) @@ -153,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) { var ids []string for _, model := range result.Data { id := model.ID - if channel.Type == common.ChannelTypeGemini { + if channel.Type == constant.ChannelTypeGemini { id = strings.TrimPrefix(id, "models/") } ids = append(ids, id) @@ -186,6 +247,8 @@ func SearchChannels(c *gin.Context) { keyword := c.Query("keyword") group := c.Query("group") modelKeyword := c.Query("model") + statusParam := c.Query("status") + statusFilter := parseStatusFilter(statusParam) idSort, _ := strconv.ParseBool(c.Query("id_sort")) enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode")) channelData := make([]*model.Channel, 0) @@ -217,10 +280,74 @@ func SearchChannels(c *gin.Context) { } channelData = channels } + + if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 { + filtered := make([]*model.Channel, 0, len(channelData)) + for _, ch := range channelData { + if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled { + continue + } + if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled { + continue + } + filtered = append(filtered, ch) + } + channelData = filtered + } + + // calculate type counts for search results + typeCounts := make(map[int64]int64) + for _, channel := range channelData { + typeCounts[int64(channel.Type)]++ + } + + typeParam := c.Query("type") + typeFilter := -1 + if typeParam != "" { + if tp, err := strconv.Atoi(typeParam); err == nil { + typeFilter = tp + } + } + + if typeFilter >= 0 { + filtered := make([]*model.Channel, 0, len(channelData)) + for _, ch := range channelData { + if ch.Type == typeFilter { + filtered = append(filtered, ch) + } + } + channelData = filtered + } + + page, _ := strconv.Atoi(c.DefaultQuery("p", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + if page < 1 { + page = 1 + } + if pageSize <= 0 { + pageSize = 20 + } + + total := len(channelData) + startIdx := (page - 1) * pageSize + if startIdx > total { + startIdx = total + } + endIdx := startIdx + pageSize + if endIdx > total { + endIdx = total + } + + pagedData := channelData[startIdx:endIdx] + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": channelData, + "data": gin.H{ + "items": pagedData, + "total": total, + "type_counts": typeCounts, + }, }) return } @@ -283,7 +410,7 @@ func AddChannel(c *gin.Context) { return } } - if addChannelRequest.Channel.Type == common.ChannelTypeVertexAi { + if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi { if addChannelRequest.Channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -566,7 +693,7 @@ func UpdateChannel(c *gin.Context) { }) return } - if channel.Type == common.ChannelTypeVertexAi { + if channel.Type == constant.ChannelTypeVertexAi { if channel.Other == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -595,6 +722,7 @@ func UpdateChannel(c *gin.Context) { }) return } + channel.Key = "" c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -620,7 +748,7 @@ func FetchModels(c *gin.Context) { baseURL := req.BaseURL if baseURL == "" { - baseURL = common.ChannelBaseURLs[req.Type] + baseURL = constant.ChannelBaseURLs[req.Type] } client := &http.Client{} diff --git a/controller/group.go b/controller/group.go index 2c725a4d..2565b6ea 100644 --- a/controller/group.go +++ b/controller/group.go @@ -1,15 +1,17 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/model" "one-api/setting" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" ) func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName, _ := range setting.GetGroupRatioCopy() { + for groupName := range ratio_setting.GetGroupRatioCopy() { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ @@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) { userGroup := "" userId := c.GetInt("id") userGroup, _ = model.GetUserGroup(userId, false) - for groupName, ratio := range setting.GetGroupRatioCopy() { + for groupName, ratio := range ratio_setting.GetGroupRatioCopy() { // UserUsableGroups contains the groups that the user can use userUsableGroups := setting.GetUserUsableGroups(userGroup) if desc, ok := userUsableGroups[groupName]; ok { @@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) { } } } + if setting.GroupInUserUsableGroups("auto") { + usableGroups["auto"] = map[string]interface{}{ + "ratio": "自动", + "desc": setting.GetUsableGroupDescription("auto"), + } + } c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/misc.go b/controller/misc.go index 33a41302..4ffe86f4 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -9,9 +9,9 @@ import ( "one-api/middleware" "one-api/model" "one-api/setting" + "one-api/setting/console_setting" "one-api/setting/operation_setting" "one-api/setting/system_setting" - "one-api/setting/console_setting" "strings" "github.com/gin-gonic/gin" @@ -41,46 +41,48 @@ func GetStatus(c *gin.Context) { cs := console_setting.GetConsoleSetting() data := gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "linuxdo_oauth": common.LinuxDOOAuthEnabled, - "linuxdo_client_id": common.LinuxDOClientId, - "telegram_oauth": common.TelegramOAuthEnabled, - "telegram_bot_name": common.TelegramBotName, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": setting.ServerAddress, - "price": setting.Price, - "min_topup": setting.MinTopUp, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "docs_link": operation_setting.GetGeneralSetting().DocsLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, - "enable_batch_update": common.BatchUpdateEnabled, - "enable_drawing": common.DrawingEnabled, - "enable_task": common.TaskEnabled, - "enable_data_export": common.DataExportEnabled, - "data_export_default_time": common.DataExportDefaultTime, - "default_collapse_sidebar": common.DefaultCollapseSidebar, - "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", - "mj_notify_enabled": setting.MjNotifyEnabled, - "chats": setting.Chats, - "demo_site_enabled": operation_setting.DemoSiteEnabled, - "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": common.EmailVerificationEnabled, + "github_oauth": common.GitHubOAuthEnabled, + "github_client_id": common.GitHubClientId, + "linuxdo_oauth": common.LinuxDOOAuthEnabled, + "linuxdo_client_id": common.LinuxDOClientId, + "telegram_oauth": common.TelegramOAuthEnabled, + "telegram_bot_name": common.TelegramBotName, + "system_name": common.SystemName, + "logo": common.Logo, + "footer_html": common.Footer, + "wechat_qrcode": common.WeChatAccountQRCodeImageURL, + "wechat_login": common.WeChatAuthEnabled, + "server_address": setting.ServerAddress, + "price": setting.Price, + "min_topup": setting.MinTopUp, + "turnstile_check": common.TurnstileCheckEnabled, + "turnstile_site_key": common.TurnstileSiteKey, + "top_up_link": common.TopUpLink, + "docs_link": operation_setting.GetGeneralSetting().DocsLink, + "quota_per_unit": common.QuotaPerUnit, + "display_in_currency": common.DisplayInCurrencyEnabled, + "enable_batch_update": common.BatchUpdateEnabled, + "enable_drawing": common.DrawingEnabled, + "enable_task": common.TaskEnabled, + "enable_data_export": common.DataExportEnabled, + "data_export_default_time": common.DataExportDefaultTime, + "default_collapse_sidebar": common.DefaultCollapseSidebar, + "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "", + "mj_notify_enabled": setting.MjNotifyEnabled, + "chats": setting.Chats, + "demo_site_enabled": operation_setting.DemoSiteEnabled, + "self_use_mode_enabled": operation_setting.SelfUseModeEnabled, + "default_use_auto_group": setting.DefaultUseAutoGroup, + "pay_methods": setting.PayMethods, // 面板启用开关 - "api_info_enabled": cs.ApiInfoEnabled, - "uptime_kuma_enabled": cs.UptimeKumaEnabled, - "announcements_enabled": cs.AnnouncementsEnabled, - "faq_enabled": cs.FAQEnabled, + "api_info_enabled": cs.ApiInfoEnabled, + "uptime_kuma_enabled": cs.UptimeKumaEnabled, + "announcements_enabled": cs.AnnouncementsEnabled, + "faq_enabled": cs.FAQEnabled, "oidc_enabled": system_setting.GetOIDCSettings().Enabled, "oidc_client_id": system_setting.GetOIDCSettings().ClientId, diff --git a/controller/model.go b/controller/model.go index df7e59a6..31a66b29 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,6 +3,7 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" + "github.com/samber/lo" "net/http" "one-api/common" "one-api/constant" @@ -14,7 +15,7 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" + "one-api/setting" ) // https://platform.openai.com/docs/api-reference/models/list @@ -23,30 +24,10 @@ var openAIModels []dto.OpenAIModels var openAIModelsMap map[string]dto.OpenAIModels var channelId2Models map[int][]string -func getPermission() []dto.OpenAIModelPermission { - var permission []dto.OpenAIModelPermission - permission = append(permission, dto.OpenAIModelPermission{ - Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", - Object: "model_permission", - Created: 1626777600, - AllowCreateEngine: true, - AllowSampling: true, - AllowLogprobs: true, - AllowSearchIndices: false, - AllowView: true, - AllowFineTuning: false, - Organization: "*", - Group: nil, - IsBlocking: false, - }) - return permission -} - func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility - permission := getPermission() - for i := 0; i < relayconstant.APITypeDummy; i++ { - if i == relayconstant.APITypeAIProxyLibrary { + for i := 0; i < constant.APITypeDummy; i++ { + if i == constant.APITypeAIProxyLibrary { continue } adaptor := relay.GetAdaptor(i) @@ -54,69 +35,51 @@ func init() { modelNames := adaptor.GetModelList() for _, modelName := range modelNames { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: channelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: channelName, }) } } for _, modelName := range ai360.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: ai360.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: ai360.ChannelName, }) } for _, modelName := range moonshot.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: moonshot.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: moonshot.ChannelName, }) } for _, modelName := range lingyiwanwu.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: lingyiwanwu.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: lingyiwanwu.ChannelName, }) } for _, modelName := range minimax.ModelList { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: minimax.ChannelName, - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: minimax.ChannelName, }) } for modelName, _ := range constant.MidjourneyModel2Action { openAIModels = append(openAIModels, dto.OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "midjourney", - Permission: permission, - Root: modelName, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "midjourney", }) } openAIModelsMap = make(map[string]dto.OpenAIModels) @@ -124,9 +87,9 @@ func init() { openAIModelsMap[aiModel.Id] = aiModel } channelId2Models = make(map[int][]string) - for i := 1; i <= common.ChannelTypeDummy; i++ { - apiType, success := relayconstant.ChannelType2APIType(i) - if !success || apiType == relayconstant.APITypeAIProxyLibrary { + for i := 1; i <= constant.ChannelTypeDummy; i++ { + apiType, success := common.ChannelType2APIType(i) + if !success || apiType == constant.APITypeAIProxyLibrary { continue } meta := &relaycommon.RelayInfo{ChannelType: i} @@ -134,15 +97,17 @@ func init() { adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } + openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string { + return m.Id + }) } func ListModels(c *gin.Context) { userOpenAiModels := make([]dto.OpenAIModels, 0) - permission := getPermission() - modelLimitEnable := c.GetBool("token_model_limit_enabled") + modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { - s, ok := c.Get("token_model_limit") + s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) var tokenModelLimit map[string]bool if ok { tokenModelLimit = s.(map[string]bool) @@ -150,23 +115,22 @@ func ListModels(c *gin.Context) { tokenModelLimit = map[string]bool{} } for allowModel, _ := range tokenModelLimit { - if _, ok := openAIModelsMap[allowModel]; ok { - userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel]) + if oaiModel, ok := openAIModelsMap[allowModel]; ok { + oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel) + userOpenAiModels = append(userOpenAiModels, oaiModel) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ - Id: allowModel, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - Permission: permission, - Root: allowModel, - Parent: nil, + Id: allowModel, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel), }) } } } else { userId := c.GetInt("id") - userGroup, err := model.GetUserGroup(userId, true) + userGroup, err := model.GetUserGroup(userId, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -175,23 +139,34 @@ func ListModels(c *gin.Context) { return } group := userGroup - tokenGroup := c.GetString("token_group") + tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) if tokenGroup != "" { group = tokenGroup } - models := model.GetGroupModels(group) - for _, s := range models { - if _, ok := openAIModelsMap[s]; ok { - userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) + var models []string + if tokenGroup == "auto" { + for _, autoGroup := range setting.AutoGroups { + groupModels := model.GetGroupEnabledModels(autoGroup) + for _, g := range groupModels { + if !common.StringsContains(models, g) { + models = append(models, g) + } + } + } + } else { + models = model.GetGroupEnabledModels(group) + } + for _, modelName := range models { + if oaiModel, ok := openAIModelsMap[modelName]; ok { + oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName) + userOpenAiModels = append(userOpenAiModels, oaiModel) } else { userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ - Id: s, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - Permission: permission, - Root: s, - Parent: nil, + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName), }) } } diff --git a/controller/option.go b/controller/option.go index 79ba2ffe..97bb6a5a 100644 --- a/controller/option.go +++ b/controller/option.go @@ -7,6 +7,7 @@ import ( "one-api/model" "one-api/setting" "one-api/setting/console_setting" + "one-api/setting/ratio_setting" "one-api/setting/system_setting" "strings" @@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) { return } case "GroupRatio": - err = setting.CheckGroupRatio(option.Value) + err = ratio_setting.CheckGroupRatio(option.Value) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/playground.go b/controller/playground.go index a2b54790..33471455 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -3,7 +3,6 @@ package controller import ( "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/constant" @@ -13,6 +12,8 @@ import ( "one-api/service" "one-api/setting" "time" + + "github.com/gin-gonic/gin" ) func Playground(c *gin.Context) { @@ -57,13 +58,22 @@ func Playground(c *gin.Context) { c.Set("group", group) } c.Set("token_name", "playground-"+group) - channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0) + channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model) openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError) return } middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) - c.Set(constant.ContextKeyRequestStartTime, time.Now()) + common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) + + // Write user context to ensure acceptUnsetRatio is available + userId := c.GetInt("id") + userCache, err := model.GetUserCache(userId) + if err != nil { + openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError) + return + } + userCache.WriteContext(c) Relay(c) } diff --git a/controller/pricing.go b/controller/pricing.go index e6a3e57f..f27336b7 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -3,7 +3,7 @@ package controller import ( "one-api/model" "one-api/setting" - "one-api/setting/operation_setting" + "one-api/setting/ratio_setting" "github.com/gin-gonic/gin" ) @@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) { userId, exists := c.Get("id") usableGroup := map[string]string{} groupRatio := map[string]float64{} - for s, f := range setting.GetGroupRatioCopy() { + for s, f := range ratio_setting.GetGroupRatioCopy() { groupRatio[s] = f } var group string @@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) { if err == nil { group = user.Group for g := range groupRatio { - ratio, ok := setting.GetGroupGroupRatio(group, g) + ratio, ok := ratio_setting.GetGroupGroupRatio(group, g) if ok { groupRatio[g] = ratio } @@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) { usableGroup = setting.GetUserUsableGroups(group) // check groupRatio contains usableGroup - for group := range setting.GetGroupRatioCopy() { + for group := range ratio_setting.GetGroupRatioCopy() { if _, ok := usableGroup[group]; !ok { delete(groupRatio, group) } @@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) { } func ResetModelRatio(c *gin.Context) { - defaultStr := operation_setting.DefaultModelRatio2JSONString() + defaultStr := ratio_setting.DefaultModelRatio2JSONString() err := model.UpdateOption("ModelRatio", defaultStr) if err != nil { c.JSON(200, gin.H{ @@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) { }) return } - err = operation_setting.UpdateModelRatioByJSONString(defaultStr) + err = ratio_setting.UpdateModelRatioByJSONString(defaultStr) if err != nil { c.JSON(200, gin.H{ "success": false, diff --git a/controller/ratio_config.go b/controller/ratio_config.go new file mode 100644 index 00000000..6ddc3d9e --- /dev/null +++ b/controller/ratio_config.go @@ -0,0 +1,24 @@ +package controller + +import ( + "net/http" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +func GetRatioConfig(c *gin.Context) { + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} \ No newline at end of file diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go new file mode 100644 index 00000000..0453870d --- /dev/null +++ b/controller/ratio_sync.go @@ -0,0 +1,474 @@ +package controller + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "one-api/common" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +const ( + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 +) + +var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + +type upstreamResult struct { + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` +} + +func FetchUpstreamRatios(c *gin.Context) { + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + + if req.Timeout <= 0 { + req.Timeout = defaultTimeoutSeconds + } + + var upstreams []dto.UpstreamDTO + + if len(req.Upstreams) > 0 { + for _, u := range req.Upstreams { + if strings.HasPrefix(u.BaseURL, "http") { + if u.Endpoint == "" { + u.Endpoint = defaultEndpoint + } + u.BaseURL = strings.TrimRight(u.BaseURL, "/") + upstreams = append(upstreams, u) + } + } + } else if len(req.ChannelIDs) > 0 { + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return + } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + ID: ch.Id, + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } + + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return + } + + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) + + sem := make(chan struct{}, maxConcurrentFetches) + + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() + + sem <- struct{}{} + defer func() { <-sem }() + + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL := chItem.BaseURL + endpoint + + uniqueName := chItem.Name + if chItem.ID != 0 { + uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } + + resp, err := client.Do(httpReq) + if err != nil { + common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) + ch <- upstreamResult{Name: uniqueName, Err: resp.Status} + return + } + // 兼容两种上游接口格式: + // type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price + // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 + var body struct { + Success bool `json:"success"` + Data json.RawMessage `json:"data"` + Message string `json:"message"` + } + + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + return + } + + if !body.Success { + ch <- upstreamResult{Name: uniqueName, Err: body.Message} + return + } + + // 尝试按 type1 解析 + var type1Data map[string]any + if err := json.Unmarshal(body.Data, &type1Data); err == nil { + // 如果包含至少一个 ratioTypes 字段,则认为是 type1 + isType1 := false + for _, rt := range ratioTypes { + if _, ok := type1Data[rt]; ok { + isType1 = true + break + } + } + if isType1 { + ch <- upstreamResult{Name: uniqueName, Data: type1Data} + return + } + } + + // 如果不是 type1,则尝试按 type2 (/api/pricing) 解析 + var pricingItems []struct { + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + CompletionRatio float64 `json:"completion_ratio"` + } + if err := json.Unmarshal(body.Data, &pricingItems); err != nil { + common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} + return + } + + modelRatioMap := make(map[string]float64) + completionRatioMap := make(map[string]float64) + modelPriceMap := make(map[string]float64) + + for _, item := range pricingItems { + if item.QuotaType == 1 { + modelPriceMap[item.ModelName] = item.ModelPrice + } else { + modelRatioMap[item.ModelName] = item.ModelRatio + // completionRatio 可能为 0,此时也直接赋值,保持与上游一致 + completionRatioMap[item.ModelName] = item.CompletionRatio + } + } + + converted := make(map[string]any) + + if len(modelRatioMap) > 0 { + ratioAny := make(map[string]any, len(modelRatioMap)) + for k, v := range modelRatioMap { + ratioAny[k] = v + } + converted["model_ratio"] = ratioAny + } + + if len(completionRatioMap) > 0 { + compAny := make(map[string]any, len(completionRatioMap)) + for k, v := range completionRatioMap { + compAny[k] = v + } + converted["completion_ratio"] = compAny + } + + if len(modelPriceMap) > 0 { + priceAny := make(map[string]any, len(modelPriceMap)) + for k, v := range modelPriceMap { + priceAny[k] = v + } + converted["model_price"] = priceAny + } + + ch <- upstreamResult{Name: uniqueName, Data: converted} + }(chn) + } + + wg.Wait() + close(ch) + + localData := ratio_setting.GetExposedData() + + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } + + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } + + differences := buildDifferences(localData, successfulChannels) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) +} + +func buildDifferences(localData map[string]any, successfulChannels []struct { + name string + data map[string]any +}) map[string]map[string]dto.DifferenceItem { + differences := make(map[string]map[string]dto.DifferenceItem) + + allModels := make(map[string]struct{}) + + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + confidenceMap := make(map[string]map[string]bool) + + // 预处理阶段:检查pricing接口的可信度 + for _, channel := range successfulChannels { + confidenceMap[channel.name] = make(map[string]bool) + + modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) + completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) + + if hasModelRatio && hasCompletionRatio { + // 遍历所有模型,检查是否满足不可信条件 + for modelName := range allModels { + // 默认为可信 + confidenceMap[channel.name][modelName] = true + + // 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1 + if modelRatioVal, ok := modelRatios[modelName]; ok { + if completionRatioVal, ok := completionRatios[modelName]; ok { + // 转换为float64进行比较 + if modelRatioFloat, ok := modelRatioVal.(float64); ok { + if completionRatioFloat, ok := completionRatioVal.(float64); ok { + if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { + confidenceMap[channel.name][modelName] = false + } + } + } + } + } + } + } else { + // 如果不是从pricing接口获取的数据,则全部标记为可信 + for modelName := range allModels { + confidenceMap[channel.name][modelName] = true + } + } + } + + for modelName := range allModels { + for _, ratioType := range ratioTypes { + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } + + upstreamValues := make(map[string]interface{}) + confidenceValues := make(map[string]bool) + hasUpstreamValue := false + hasDifference := false + + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil + + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + if upstreamValue == nil && localValue == nil { + upstreamValue = "same" + } + + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + + confidenceValues[channel.name] = confidenceMap[channel.name][modelName] + } + + shouldInclude := false + + if localValue != nil { + if hasDifference { + shouldInclude = true + } + } else { + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + Confidence: confidenceValues, + } + } + } + } + + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + delete(item.Confidence, chName) + } + } + + allSame := true + for _, v := range item.Upstreams { + if v != "same" { + allSame = false + break + } + } + if len(item.Upstreams) == 0 || allSame { + delete(ratioMap, ratioType) + } else { + differences[modelName][ratioType] = item + } + } + + if len(ratioMap) == 0 { + delete(differences, modelName) + } + } + + return differences +} + +func GetSyncableChannels(c *gin.Context) { + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} \ No newline at end of file diff --git a/controller/relay.go b/controller/relay.go index 1a875dbc..e375120b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -8,12 +8,12 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" constant2 "one-api/constant" "one-api/dto" "one-api/middleware" "one-api/model" "one-api/relay" - "one-api/relay/constant" relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" @@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode } func Relay(c *gin.Context) { - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") originalModel := c.GetString("original_model") @@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) { return } - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 @@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m AutoBan: &autoBanInt, }, nil } - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) + channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) if err != nil { return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) } @@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry } if openaiErr.StatusCode == http.StatusBadRequest { channelType := c.GetInt("channel_type") - if channelType == common.ChannelTypeAnthropic { + if channelType == constant.ChannelTypeAnthropic { return true } return false @@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) { retryTimes = 0 } for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i) if err != nil { common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) break @@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) { func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { var err *dto.TaskError switch relayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID: err = relay.RelayTaskFetch(c, relayMode) default: err = relay.RelayTaskSubmit(c, relayMode) diff --git a/controller/task.go b/controller/task.go index 34e14f3f..5cfa728a 100644 --- a/controller/task.go +++ b/controller/task.go @@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) case constant.TaskPlatformSuno: _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) + case constant.TaskPlatformKling, constant.TaskPlatformJimeng: + _ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM) default: common.SysLog("未知平台") } diff --git a/controller/task_video.go b/controller/task_video.go new file mode 100644 index 00000000..b62978a7 --- /dev/null +++ b/controller/task_video.go @@ -0,0 +1,138 @@ +package controller + +import ( + "context" + "fmt" + "io" + "one-api/common" + "one-api/constant" + "one-api/model" + "one-api/relay" + "one-api/relay/channel" + "time" +) + +func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { + common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := relay.GetTaskAdaptor(platform) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := constant.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + + task := taskM[taskId] + if task == nil { + common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ + "task_id": taskId, + "action": task.Action, + }) + if err != nil { + return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) + } + //if resp.StatusCode != http.StatusOK { + //return fmt.Errorf("get Video Task status code: %d", resp.StatusCode) + //} + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("readAll failed for task %s: %w", taskId, err) + } + + taskResult, err := adaptor.ParseTaskResult(responseBody) + if err != nil { + return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) + } + //if taskResult.Code != 0 { + // return fmt.Errorf("video task fetch failed for task %s", taskId) + //} + + now := time.Now().Unix() + if taskResult.Status == "" { + return fmt.Errorf("task %s status is empty", taskId) + } + task.Status = model.TaskStatus(taskResult.Status) + switch taskResult.Status { + case model.TaskStatusSubmitted: + task.Progress = "10%" + case model.TaskStatusQueued: + task.Progress = "20%" + case model.TaskStatusInProgress: + task.Progress = "30%" + if task.StartTime == 0 { + task.StartTime = now + } + case model.TaskStatusSuccess: + task.Progress = "100%" + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Url + case model.TaskStatusFailure: + task.Status = model.TaskStatusFailure + task.Progress = "100%" + if task.FinishTime == 0 { + task.FinishTime = now + } + task.FailReason = taskResult.Reason + common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + quota := task.Quota + if quota != 0 { + if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { + common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + default: + return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId) + } + if taskResult.Progress != "" { + task.Progress = taskResult.Progress + } + + task.Data = responseBody + if err := task.Update(); err != nil { + common.SysError("UpdateVideoTask task error: " + err.Error()) + } + + return nil +} diff --git a/controller/token.go b/controller/token.go index c57552c0..173fc22e 100644 --- a/controller/token.go +++ b/controller/token.go @@ -258,3 +258,32 @@ func UpdateToken(c *gin.Context) { }) return } + +type TokenBatch struct { + Ids []int `json:"ids"` +} + +func DeleteTokenBatch(c *gin.Context) { + tokenBatch := TokenBatch{} + if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "参数错误", + }) + return + } + userId := c.GetInt("id") + count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": count, + }) +} diff --git a/controller/topup.go b/controller/topup.go index 951b2cf2..827dda39 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) { c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) return } - payType := "wxpay" - if req.PaymentMethod == "zfb" { - payType = "alipay" - } - if req.PaymentMethod == "wx" { - req.PaymentMethod = "wxpay" - payType = "wxpay" + + if !setting.ContainsPayMethod(req.PaymentMethod) { + c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"}) + return } + callBackAddress := service.GetCallbackAddress() returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") @@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) { return } uri, params, err := client.Purchase(&epay.PurchaseArgs{ - Type: payType, + Type: req.PaymentMethod, ServiceTradeNo: tradeNo, Name: fmt.Sprintf("TUC%d", req.Amount), Money: strconv.FormatFloat(payMoney, 'f', 2, 64), diff --git a/controller/user.go b/controller/user.go index d7eb42d7..ca161f42 100644 --- a/controller/user.go +++ b/controller/user.go @@ -226,6 +226,9 @@ func Register(c *gin.Context) { UnlimitedQuota: true, ModelLimitsEnabled: false, } + if setting.DefaultUseAutoGroup { + token.Group = "auto" + } if err := token.Insert(); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -243,15 +246,15 @@ func Register(c *gin.Context) { } func GetAllUsers(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - pageSize, _ := strconv.Atoi(c.Query("page_size")) - if p < 1 { - p = 1 + pageInfo, err := common.GetPageQuery(c) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "parse page query failed", + }) + return } - if pageSize < 0 { - pageSize = common.ItemsPerPage - } - users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize) + users, total, err := model.GetAllUsers(pageInfo) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -259,15 +262,13 @@ func GetAllUsers(c *gin.Context) { }) return } + + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(users) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": gin.H{ - "items": users, - "total": total, - "page": p, - "page_size": pageSize, - }, + "data": pageInfo, }) return } @@ -459,6 +460,9 @@ func GetSelf(c *gin.Context) { }) return } + // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users + user.Remark = "" + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -483,7 +487,7 @@ func GetUserModels(c *gin.Context) { groups := setting.GetUserUsableGroups(user.Group) var models []string for group := range groups { - for _, g := range model.GetGroupModels(group) { + for _, g := range model.GetGroupEnabledModels(group) { if !common.StringsContains(models, g) { models = append(models, g) } diff --git a/docker-compose.yml b/docker-compose.yml index fef6a803..57ad0b30 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: '3.4' # 兼容旧版docker-compose +version: '3.4' services: new-api: @@ -16,6 +16,7 @@ services: - REDIS_CONN_STRING=redis://redis - TZ=Asia/Shanghai - ERROR_LOG_ENABLED=true # 是否启用错误日志记录 + # - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值 # - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!! # - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment # - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed diff --git a/dto/claude.go b/dto/claude.go index 4d24bc70..98e09c78 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -178,7 +178,14 @@ type ClaudeRequest struct { type Thinking struct { Type string `json:"type"` - BudgetTokens int `json:"budget_tokens"` + BudgetTokens *int `json:"budget_tokens,omitempty"` +} + +func (c *Thinking) GetBudgetTokens() int { + if c.BudgetTokens == nil { + return 0 + } + return *c.BudgetTokens } func (c *ClaudeRequest) IsStringSystem() bool { diff --git a/dto/dalle.go b/dto/dalle.go index a1309b6c..ce2f6361 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -15,6 +15,7 @@ type ImageRequest struct { Background string `json:"background,omitempty"` Moderation string `json:"moderation,omitempty"` OutputFormat string `json:"output_format,omitempty"` + Watermark *bool `json:"watermark,omitempty"` } type ImageResponse struct { diff --git a/dto/midjourney.go b/dto/midjourney.go index 40251ee9..6fbcb357 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -57,6 +57,8 @@ type MidjourneyDto struct { StartTime int64 `json:"startTime"` FinishTime int64 `json:"finishTime"` ImageUrl string `json:"imageUrl"` + VideoUrl string `json:"videoUrl"` + VideoUrls []ImgUrls `json:"videoUrls"` Status string `json:"status"` Progress string `json:"progress"` FailReason string `json:"failReason"` @@ -65,6 +67,10 @@ type MidjourneyDto struct { Properties *Properties `json:"properties"` } +type ImgUrls struct { + Url string `json:"url"` +} + type MidjourneyStatus struct { Status int `json:"status"` } diff --git a/dto/openai_request.go b/dto/openai_request.go index 299171ba..a6567542 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -53,9 +53,11 @@ type GeneralOpenAIRequest struct { Modalities json.RawMessage `json:"modalities,omitempty"` Audio json.RawMessage `json:"audio,omitempty"` EnableThinking any `json:"enable_thinking,omitempty"` // ali + THINKING json.RawMessage `json:"thinking,omitempty"` // doubao ExtraBody json.RawMessage `json:"extra_body,omitempty"` WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` // OpenRouter Params + Usage json.RawMessage `json:"usage,omitempty"` Reasoning json.RawMessage `json:"reasoning,omitempty"` // Ali Qwen Params VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"` @@ -64,7 +66,7 @@ type GeneralOpenAIRequest struct { func (r *GeneralOpenAIRequest) ToMap() map[string]any { result := make(map[string]any) data, _ := common.EncodeJson(r) - _ = common.DecodeJson(data, &result) + _ = common.UnmarshalJson(data, &result) return result } @@ -644,4 +646,6 @@ type ResponsesToolsCall struct { Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Parameters json.RawMessage `json:"parameters,omitempty"` + Function json.RawMessage `json:"function,omitempty"` + Container json.RawMessage `json:"container,omitempty"` } diff --git a/dto/openai_response.go b/dto/openai_response.go index 790d4df8..d95acd9e 100644 --- a/dto/openai_response.go +++ b/dto/openai_response.go @@ -26,7 +26,7 @@ type OpenAITextResponse struct { Id string `json:"id"` Model string `json:"model"` Object string `json:"object"` - Created int64 `json:"created"` + Created any `json:"created"` Choices []OpenAITextResponseChoice `json:"choices"` Error *OpenAIError `json:"error,omitempty"` Usage `json:"usage"` @@ -178,6 +178,8 @@ type Usage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` InputTokensDetails *InputTokenDetails `json:"input_tokens_details"` + // OpenRouter Params + Cost float64 `json:"cost,omitempty"` } type InputTokenDetails struct { diff --git a/dto/pricing.go b/dto/pricing.go index ee77c098..0f317d9d 100644 --- a/dto/pricing.go +++ b/dto/pricing.go @@ -1,26 +1,11 @@ package dto -type OpenAIModelPermission struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - AllowCreateEngine bool `json:"allow_create_engine"` - AllowSampling bool `json:"allow_sampling"` - AllowLogprobs bool `json:"allow_logprobs"` - AllowSearchIndices bool `json:"allow_search_indices"` - AllowView bool `json:"allow_view"` - AllowFineTuning bool `json:"allow_fine_tuning"` - Organization string `json:"organization"` - Group *string `json:"group"` - IsBlocking bool `json:"is_blocking"` -} +import "one-api/constant" type OpenAIModels struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - OwnedBy string `json:"owned_by"` - Permission []OpenAIModelPermission `json:"permission"` - Root string `json:"root"` - Parent *string `json:"parent"` + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go new file mode 100644 index 00000000..6315f31a --- /dev/null +++ b/dto/ratio_sync.go @@ -0,0 +1,38 @@ +package dto + +type UpstreamDTO struct { + ID int `json:"id,omitempty"` + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` +} + +type UpstreamRequest struct { + ChannelIDs []int64 `json:"channel_ids"` + Upstreams []UpstreamDTO `json:"upstreams"` + Timeout int `json:"timeout"` +} + +// TestResult 上游测试连通性结果 +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +// DifferenceItem 差异项 +// Current 为本地值,可能为 nil +// Upstreams 为各渠道的上游值,具体数值 / "same" / nil + +type DifferenceItem struct { + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` + Confidence map[string]bool `json:"confidence"` +} + +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} \ No newline at end of file diff --git a/dto/rerank.go b/dto/rerank.go index 21f6437c..5ea68cba 100644 --- a/dto/rerank.go +++ b/dto/rerank.go @@ -4,7 +4,7 @@ type RerankRequest struct { Documents []any `json:"documents"` Query string `json:"query"` Model string `json:"model"` - TopN int `json:"top_n"` + TopN int `json:"top_n,omitempty"` ReturnDocuments *bool `json:"return_documents,omitempty"` MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"` OverLapTokens int `json:"overlap_tokens,omitempty"` diff --git a/dto/video.go b/dto/video.go new file mode 100644 index 00000000..5b48146a --- /dev/null +++ b/dto/video.go @@ -0,0 +1,47 @@ +package dto + +type VideoRequest struct { + Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID + Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt + Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64) + Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds) + Width int `json:"width" example:"512"` // Video width + Height int `json:"height" example:"512"` // Video height + Fps int `json:"fps,omitempty" example:"30"` // Video frame rate + Seed int `json:"seed,omitempty" example:"20231234"` // Random seed + N int `json:"n,omitempty" example:"1"` // Number of videos to generate + ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format + User string `json:"user,omitempty" example:"user-1234"` // User identifier + Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.) +} + +// VideoResponse 视频生成提交任务后的响应 +type VideoResponse struct { + TaskId string `json:"task_id"` + Status string `json:"status"` +} + +// VideoTaskResponse 查询视频生成任务状态的响应 +type VideoTaskResponse struct { + TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID + Status string `json:"status" example:"succeeded"` // 任务状态 + Url string `json:"url,omitempty"` // 视频资源URL(成功时) + Format string `json:"format,omitempty" example:"mp4"` // 视频格式 + Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据 + Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时) +} + +// VideoTaskMetadata 视频任务元数据 +type VideoTaskMetadata struct { + Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长 + Fps int `json:"fps" example:"30"` // 实际帧率 + Width int `json:"width" example:"512"` // 实际宽度 + Height int `json:"height" example:"512"` // 实际高度 + Seed int `json:"seed" example:"20231234"` // 使用的随机种子 +} + +// VideoTaskError 视频任务错误信息 +type VideoTaskError struct { + Code int `json:"code"` + Message string `json:"message"` +} diff --git a/i18n/zh-cn.json b/i18n/zh-cn.json new file mode 100644 index 00000000..7b57b51a --- /dev/null +++ b/i18n/zh-cn.json @@ -0,0 +1,1041 @@ +{ + "未登录或登录已过期,请重新登录": "未登录或登录已过期,请重新登录", + "登 录": "登 录", + "使用 微信 继续": "使用 微信 继续", + "使用 GitHub 继续": "使用 GitHub 继续", + "使用 LinuxDO 继续": "使用 LinuxDO 继续", + "使用 邮箱或用户名 登录": "使用 邮箱或用户名 登录", + "没有账户?": "没有账户?", + "用户名或邮箱": "用户名或邮箱", + "请输入您的用户名或邮箱地址": "请输入您的用户名或邮箱地址", + "请输入您的密码": "请输入您的密码", + "继续": "继续", + "忘记密码?": "忘记密码?", + "其他登录选项": "其他登录选项", + "微信扫码登录": "微信扫码登录", + "登录": "登录", + "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)", + "验证码": "验证码", + "处理中...": "处理中...", + "绑定成功!": "绑定成功!", + "登录成功!": "登录成功!", + "操作失败,重定向至登录界面中...": "操作失败,重定向至登录界面中...", + "出现错误,第 ${count} 次重试中...": "出现错误,第 ${count} 次重试中...", + "无效的重置链接,请重新发起密码重置请求": "无效的重置链接,请重新发起密码重置请求", + "密码已重置并已复制到剪贴板:": "密码已重置并已复制到剪贴板:", + "密码重置确认": "密码重置确认", + "等待获取邮箱信息...": "等待获取邮箱信息...", + "新密码": "新密码", + "密码已复制到剪贴板:": "密码已复制到剪贴板:", + "密码重置完成": "密码重置完成", + "确认重置密码": "确认重置密码", + "返回登录": "返回登录", + "请输入邮箱地址": "请输入邮箱地址", + "请稍后几秒重试,Turnstile 正在检查用户环境!": "请稍后几秒重试,Turnstile 正在检查用户环境!", + "重置邮件发送成功,请检查邮箱!": "重置邮件发送成功,请检查邮箱!", + "密码重置": "密码重置", + "请输入您的邮箱地址": "请输入您的邮箱地址", + "重试": "重试", + "想起来了?": "想起来了?", + "注 册": "注 册", + "使用 用户名 注册": "使用 用户名 注册", + "已有账户?": "已有账户?", + "用户名": "用户名", + "请输入用户名": "请输入用户名", + "输入密码,最短 8 位,最长 20 位": "输入密码,最短 8 位,最长 20 位", + "确认密码": "确认密码", + "输入邮箱地址": "输入邮箱地址", + "获取验证码": "获取验证码", + "输入验证码": "输入验证码", + "或": "或", + "其他注册选项": "其他注册选项", + "加载中...": "加载中...", + "复制代码": "复制代码", + "代码已复制到剪贴板": "代码已复制到剪贴板", + "复制失败,请手动复制": "复制失败,请手动复制", + "显示更多": "显示更多", + "关于我们": "关于我们", + "关于项目": "关于项目", + "联系我们": "联系我们", + "功能特性": "功能特性", + "快速开始": "快速开始", + "安装指南": "安装指南", + "API 文档": "API 文档", + "基于New API的项目": "基于New API的项目", + "版权所有": "版权所有", + "设计与开发由": "设计与开发由", + "首页": "首页", + "控制台": "控制台", + "文档": "文档", + "关于": "关于", + "注销成功!": "注销成功!", + "个人设置": "个人设置", + "API令牌": "API令牌", + "退出": "退出", + "关闭侧边栏": "关闭侧边栏", + "打开侧边栏": "打开侧边栏", + "关闭菜单": "关闭菜单", + "打开菜单": "打开菜单", + "演示站点": "演示站点", + "自用模式": "自用模式", + "系统公告": "系统公告", + "切换主题": "切换主题", + "切换语言": "切换语言", + "暂无公告": "暂无公告", + "暂无系统公告": "暂无系统公告", + "今日关闭": "今日关闭", + "关闭公告": "关闭公告", + "数据看板": "数据看板", + "绘图日志": "绘图日志", + "任务日志": "任务日志", + "渠道": "渠道", + "兑换码": "兑换码", + "用户管理": "用户管理", + "操练场": "操练场", + "聊天": "聊天", + "管理员": "管理员", + "个人中心": "个人中心", + "展开侧边栏": "展开侧边栏", + "AI 对话": "AI 对话", + "选择模型开始对话": "选择模型开始对话", + "显示调试": "显示调试", + "请输入您的问题...": "请输入您的问题...", + "已复制到剪贴板": "已复制到剪贴板", + "复制失败": "复制失败", + "正在构造请求体预览...": "正在构造请求体预览...", + "暂无请求数据": "暂无请求数据", + "暂无响应数据": "暂无响应数据", + "内容较大,已启用性能优化模式": "内容较大,已启用性能优化模式", + "内容较大,部分功能可能受限": "内容较大,部分功能可能受限", + "已复制": "已复制", + "正在处理大内容...": "正在处理大内容...", + "显示完整内容": "显示完整内容", + "收起": "收起", + "配置已导出到下载文件夹": "配置已导出到下载文件夹", + "导出配置失败: ": "导出配置失败: ", + "确认导入配置": "确认导入配置", + "导入的配置将覆盖当前设置,是否继续?": "导入的配置将覆盖当前设置,是否继续?", + "取消": "取消", + "配置导入成功": "配置导入成功", + "导入配置失败: ": "导入配置失败: ", + "重置配置": "重置配置", + "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?": "将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?", + "重置选项": "重置选项", + "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。": "是否同时重置对话消息?选择\"是\"将清空所有对话记录并恢复默认示例;选择\"否\"将保留当前对话记录。", + "同时重置消息": "同时重置消息", + "仅重置配置": "仅重置配置", + "配置和消息已全部重置": "配置和消息已全部重置", + "配置已重置,对话消息已保留": "配置已重置,对话消息已保留", + "已有保存的配置": "已有保存的配置", + "暂无保存的配置": "暂无保存的配置", + "导出配置": "导出配置", + "导入配置": "导入配置", + "导出": "导出", + "导入": "导入", + "调试信息": "调试信息", + "预览请求体": "预览请求体", + "实际请求体": "实际请求体", + "预览更新": "预览更新", + "最后请求": "最后请求", + "操作暂时被禁用": "操作暂时被禁用", + "复制": "复制", + "编辑": "编辑", + "切换为System角色": "切换为System角色", + "切换为Assistant角色": "切换为Assistant角色", + "删除": "删除", + "请求发生错误": "请求发生错误", + "系统消息": "系统消息", + "请输入消息内容...": "请输入消息内容...", + "保存": "保存", + "模型配置": "模型配置", + "分组": "分组", + "请选择分组": "请选择分组", + "请选择模型": "请选择模型", + "思考中...": "思考中...", + "思考过程": "思考过程", + "选择同步渠道": "选择同步渠道", + "搜索渠道名称或地址": "搜索渠道名称或地址", + "暂无渠道": "暂无渠道", + "暂无选择": "暂无选择", + "无搜索结果": "无搜索结果", + "公告已更新": "公告已更新", + "公告更新失败": "公告更新失败", + "系统名称已更新": "系统名称已更新", + "系统名称更新失败": "系统名称更新失败", + "系统信息": "系统信息", + "当前版本": "当前版本", + "检查更新": "检查更新", + "启动时间": "启动时间", + "通用设置": "通用设置", + "设置公告": "设置公告", + "个性化设置": "个性化设置", + "系统名称": "系统名称", + "在此输入系统名称": "在此输入系统名称", + "设置系统名称": "设置系统名称", + "Logo 图片地址": "Logo 图片地址", + "在此输入 Logo 图片地址": "在此输入 Logo 图片地址", + "首页内容": "首页内容", + "设置首页内容": "设置首页内容", + "设置关于": "设置关于", + "页脚": "页脚", + "设置页脚": "设置页脚", + "详情": "详情", + "刷新失败": "刷新失败", + "令牌已重置并已复制到剪贴板": "令牌已重置并已复制到剪贴板", + "加载模型列表失败": "加载模型列表失败", + "系统令牌已复制到剪切板": "系统令牌已复制到剪切板", + "请输入你的账户名以确认删除!": "请输入你的账户名以确认删除!", + "账户已删除!": "账户已删除!", + "微信账户绑定成功!": "微信账户绑定成功!", + "请输入原密码!": "请输入原密码!", + "请输入新密码!": "请输入新密码!", + "新密码需要和原密码不一致!": "新密码需要和原密码不一致!", + "两次输入的密码不一致!": "两次输入的密码不一致!", + "密码修改成功!": "密码修改成功!", + "验证码发送成功,请检查邮箱!": "验证码发送成功,请检查邮箱!", + "请输入邮箱验证码!": "请输入邮箱验证码!", + "邮箱账户绑定成功!": "邮箱账户绑定成功!", + "无法复制到剪贴板,请手动复制": "无法复制到剪贴板,请手动复制", + "设置保存成功": "设置保存成功", + "设置保存失败": "设置保存失败", + "超级管理员": "超级管理员", + "普通用户": "普通用户", + "当前余额": "当前余额", + "历史消耗": "历史消耗", + "请求次数": "请求次数", + "默认": "默认", + "可用模型": "可用模型", + "模型列表": "模型列表", + "点击模型名称可复制": "点击模型名称可复制", + "没有可用模型": "没有可用模型", + "该分类下没有可用模型": "该分类下没有可用模型", + "更多": "更多", + "个模型": "个模型", + "账户绑定": "账户绑定", + "未绑定": "未绑定", + "修改绑定": "修改绑定", + "微信": "微信", + "已绑定": "已绑定", + "未启用": "未启用", + "绑定": "绑定", + "安全设置": "安全设置", + "系统访问令牌": "系统访问令牌", + "用于API调用的身份验证令牌,请妥善保管": "用于API调用的身份验证令牌,请妥善保管", + "生成令牌": "生成令牌", + "密码管理": "密码管理", + "定期更改密码可以提高账户安全性": "定期更改密码可以提高账户安全性", + "修改密码": "修改密码", + "此操作不可逆,所有数据将被永久删除": "此操作不可逆,所有数据将被永久删除", + "删除账户": "删除账户", + "其他设置": "其他设置", + "通知设置": "通知设置", + "邮件通知": "邮件通知", + "通过邮件接收通知": "通过邮件接收通知", + "Webhook通知": "Webhook通知", + "通过HTTP请求接收通知": "通过HTTP请求接收通知", + "请输入Webhook地址,例如: https://example.com/webhook": "请输入Webhook地址,例如: https://example.com/webhook", + "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求": "只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求", + "接口凭证(可选)": "接口凭证(可选)", + "请输入密钥": "请输入密钥", + "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性": "密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性", + "通知邮箱": "通知邮箱", + "留空则使用账号绑定的邮箱": "留空则使用账号绑定的邮箱", + "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱": "设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱", + "额度预警阈值": "额度预警阈值", + "请输入预警额度": "请输入预警额度", + "当剩余额度低于此数值时,系统将通过选择的方式发送通知": "当剩余额度低于此数值时,系统将通过选择的方式发送通知", + "接受未设置价格模型": "接受未设置价格模型", + "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用": "当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用", + "IP记录": "IP记录", + "记录请求与错误日志 IP": "记录请求与错误日志 IP", + "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址": "开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址", + "绑定邮箱地址": "绑定邮箱地址", + "重新发送": "重新发送", + "绑定微信账户": "绑定微信账户", + "删除账户确认": "删除账户确认", + "您正在删除自己的帐户,将清空所有数据且不可恢复": "您正在删除自己的帐户,将清空所有数据且不可恢复", + "请输入您的用户名以确认删除": "请输入您的用户名以确认删除", + "输入你的账户名{{username}}以确认删除": "输入你的账户名{{username}}以确认删除", + "原密码": "原密码", + "请输入原密码": "请输入原密码", + "请输入新密码": "请输入新密码", + "确认新密码": "确认新密码", + "请再次输入新密码": "请再次输入新密码", + "模型倍率设置": "模型倍率设置", + "可视化倍率设置": "可视化倍率设置", + "未设置倍率模型": "未设置倍率模型", + "上游倍率同步": "上游倍率同步", + "未知类型": "未知类型", + "标签聚合": "标签聚合", + "已启用": "已启用", + "自动禁用": "自动禁用", + "未知状态": "未知状态", + "未测试": "未测试", + "名称": "名称", + "类型": "类型", + "状态": "状态", + ",时间:": ",时间:", + "响应时间": "响应时间", + "已用/剩余": "已用/剩余", + "剩余额度$": "剩余额度$", + ",点击更新": ",点击更新", + "已用额度": "已用额度", + "修改子渠道优先级": "修改子渠道优先级", + "确定要修改所有子渠道优先级为 ": "确定要修改所有子渠道优先级为 ", + "权重": "权重", + "修改子渠道权重": "修改子渠道权重", + "确定要修改所有子渠道权重为 ": "确定要修改所有子渠道权重为 ", + "确定是否要删除此渠道?": "确定是否要删除此渠道?", + "此修改将不可逆": "此修改将不可逆", + "确定是否要复制此渠道?": "确定是否要复制此渠道?", + "复制渠道的所有信息": "复制渠道的所有信息", + "测试单个渠道操作项目组": "测试单个渠道操作项目组", + "禁用": "禁用", + "启用": "启用", + "启用全部": "启用全部", + "禁用全部": "禁用全部", + "重置": "重置", + "全选": "全选", + "_复制": "_复制", + "渠道未找到,请刷新页面后重试。": "渠道未找到,请刷新页面后重试。", + "渠道复制成功": "渠道复制成功", + "渠道复制失败: ": "渠道复制失败: ", + "操作成功完成!": "操作成功完成!", + "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。", + "已停止测试": "已停止测试", + "全部": "全部", + "请先选择要设置标签的渠道!": "请先选择要设置标签的渠道!", + "标签不能为空!": "标签不能为空!", + "已为 ${count} 个渠道设置标签!": "已为 ${count} 个渠道设置标签!", + "已成功开始测试所有已启用通道,请刷新页面查看结果。": "已成功开始测试所有已启用通道,请刷新页面查看结果。", + "已删除所有禁用渠道,共计 ${data} 个": "已删除所有禁用渠道,共计 ${data} 个", + "已更新完毕所有已启用通道余额!": "已更新完毕所有已启用通道余额!", + "通道 ${name} 余额更新成功!": "通道 ${name} 余额更新成功!", + "已删除 ${data} 个通道!": "已删除 ${data} 个通道!", + "已修复 ${data} 个通道!": "已修复 ${data} 个通道!", + "确定是否要删除所选通道?": "确定是否要删除所选通道?", + "删除所选通道": "删除所选通道", + "批量设置标签": "批量设置标签", + "确定要测试所有通道吗?": "确定要测试所有通道吗?", + "测试所有通道": "测试所有通道", + "确定要更新所有已启用通道余额吗?": "确定要更新所有已启用通道余额吗?", + "更新所有已启用通道余额": "更新所有已启用通道余额", + "确定是否要删除禁用通道?": "确定是否要删除禁用通道?", + "删除禁用通道": "删除禁用通道", + "确定是否要修复数据库一致性?": "确定是否要修复数据库一致性?", + "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用", + "批量操作": "批量操作", + "使用ID排序": "使用ID排序", + "开启批量操作": "开启批量操作", + "标签聚合模式": "标签聚合模式", + "刷新": "刷新", + "列设置": "列设置", + "搜索渠道的 ID,名称,密钥和API地址 ...": "搜索渠道的 ID,名称,密钥和API地址 ...", + "模型关键字": "模型关键字", + "选择分组": "选择分组", + "查询": "查询", + "第 {{start}} - {{end}} 条,共 {{total}} 条": "第 {{start}} - {{end}} 条,共 {{total}} 条", + "搜索无结果": "搜索无结果", + "请输入要设置的标签名称": "请输入要设置的标签名称", + "请输入标签名称": "请输入标签名称", + "已选择 ${count} 个渠道": "已选择 ${count} 个渠道", + "共": "共", + "停止测试": "停止测试", + "测试中...": "测试中...", + "批量测试${count}个模型": "批量测试${count}个模型", + "搜索模型...": "搜索模型...", + "模型名称": "模型名称", + "测试中": "测试中", + "未开始": "未开始", + "失败": "失败", + "请求时长: ${time}s": "请求时长: ${time}s", + "充值": "充值", + "消费": "消费", + "系统": "系统", + "错误": "错误", + "流": "流", + "非流": "非流", + "请求并计费模型": "请求并计费模型", + "实际模型": "实际模型", + "用户": "用户", + "用时/首字": "用时/首字", + "提示": "提示", + "花费": "花费", + "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录": "只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录", + "确定": "确定", + "用户信息": "用户信息", + "渠道信息": "渠道信息", + "语音输入": "语音输入", + "文字输入": "文字输入", + "文字输出": "文字输出", + "缓存创建 Tokens": "缓存创建 Tokens", + "日志详情": "日志详情", + "消耗额度": "消耗额度", + "开始时间": "开始时间", + "结束时间": "结束时间", + "用户名称": "用户名称", + "日志类型": "日志类型", + "绘图": "绘图", + "放大": "放大", + "变换": "变换", + "强变换": "强变换", + "平移": "平移", + "图生文": "图生文", + "图混合": "图混合", + "重绘": "重绘", + "局部重绘-提交": "局部重绘-提交", + "自定义变焦-提交": "自定义变焦-提交", + "窗口处理": "窗口处理", + "未知": "未知", + "已提交": "已提交", + "等待中": "等待中", + "重复提交": "重复提交", + "成功": "成功", + "未启动": "未启动", + "执行中": "执行中", + "窗口等待": "窗口等待", + "秒": "秒", + "提交时间": "提交时间", + "花费时间": "花费时间", + "任务ID": "任务ID", + "提交结果": "提交结果", + "任务状态": "任务状态", + "结果图片": "结果图片", + "查看图片": "查看图片", + "无": "无", + "失败原因": "失败原因", + "已复制:": "已复制:", + "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。": "当前未开启Midjourney回调,部分项目可能无法获得绘图结果,可在运营设置中开启。", + "Midjourney 任务记录": "Midjourney 任务记录", + "任务 ID": "任务 ID", + "按次计费": "按次计费", + "按量计费": "按量计费", + "您的分组可以使用该模型": "您的分组可以使用该模型", + "可用性": "可用性", + "计费类型": "计费类型", + "当前查看的分组为:{{group}},倍率为:{{ratio}}": "当前查看的分组为:{{group}},倍率为:{{ratio}}", + "倍率": "倍率", + "倍率是为了方便换算不同价格的模型": "倍率是为了方便换算不同价格的模型", + "模型倍率": "模型倍率", + "补全倍率": "补全倍率", + "分组倍率": "分组倍率", + "模型价格": "模型价格", + "补全": "补全", + "模糊搜索模型名称": "模糊搜索模型名称", + "复制选中模型": "复制选中模型", + "模型定价": "模型定价", + "当前分组": "当前分组", + "未登录,使用默认分组倍率": "未登录,使用默认分组倍率", + "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)": "按量计费费用 = 分组倍率 × 模型倍率 × (提示token数 + 补全token数 × 补全倍率)/ 500000 (单位:美元)", + "已过期": "已过期", + "未使用": "未使用", + "已禁用": "已禁用", + "创建时间": "创建时间", + "过期时间": "过期时间", + "永不过期": "永不过期", + "确定是否要删除此兑换码?": "确定是否要删除此兑换码?", + "查看": "查看", + "已复制到剪贴板!": "已复制到剪贴板!", + "兑换码可以批量生成和分发,适合用于推广活动或批量充值。": "兑换码可以批量生成和分发,适合用于推广活动或批量充值。", + "添加兑换码": "添加兑换码", + "请至少选择一个兑换码!": "请至少选择一个兑换码!", + "复制所选兑换码到剪贴板": "复制所选兑换码到剪贴板", + "确定清除所有失效兑换码?": "确定清除所有失效兑换码?", + "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。", + "已删除 {{count}} 条失效兑换码": "已删除 {{count}} 条失效兑换码", + "关键字(id或者名称)": "关键字(id或者名称)", + "生成音乐": "生成音乐", + "生成歌词": "生成歌词", + "生成视频": "生成视频", + "排队中": "排队中", + "正在提交": "正在提交", + "平台": "平台", + "点击预览视频": "点击预览视频", + "任务记录": "任务记录", + "渠道 ID": "渠道 ID", + "已启用:限制模型": "已启用:限制模型", + "已耗尽": "已耗尽", + "剩余额度": "剩余额度", + "聊天链接配置错误,请联系管理员": "聊天链接配置错误,请联系管理员", + "令牌详情": "令牌详情", + "确定是否要删除此令牌?": "确定是否要删除此令牌?", + "项目操作按钮组": "项目操作按钮组", + "请联系管理员配置聊天链接": "请联系管理员配置聊天链接", + "令牌用于API访问认证,可以设置额度限制和模型权限。": "令牌用于API访问认证,可以设置额度限制和模型权限。", + "添加令牌": "添加令牌", + "请至少选择一个令牌!": "请至少选择一个令牌!", + "复制所选令牌到剪贴板": "复制所选令牌到剪贴板", + "搜索关键字": "搜索关键字", + "未知身份": "未知身份", + "已封禁": "已封禁", + "统计信息": "统计信息", + "剩余": "剩余", + "调用": "调用", + "邀请信息": "邀请信息", + "收益": "收益", + "无邀请人": "无邀请人", + "已注销": "已注销", + "确定要提升此用户吗?": "确定要提升此用户吗?", + "此操作将提升用户的权限级别": "此操作将提升用户的权限级别", + "确定要降级此用户吗?": "确定要降级此用户吗?", + "此操作将降低用户的权限级别": "此操作将降低用户的权限级别", + "确定是否要注销此用户?": "确定是否要注销此用户?", + "相当于删除用户,此修改将不可逆": "相当于删除用户,此修改将不可逆", + "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。": "用户管理页面,可以查看和管理所有注册用户的信息、权限和状态。", + "添加用户": "添加用户", + "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "支持搜索用户的 ID、用户名、显示名称和邮箱地址", + "全部模型": "全部模型", + "智谱": "智谱", + "通义千问": "通义千问", + "文心一言": "文心一言", + "腾讯混元": "腾讯混元", + "360智脑": "360智脑", + "豆包": "豆包", + "用户分组": "用户分组", + "专属倍率": "专属倍率", + "输入价格:${{price}} / 1M tokens{{audioPrice}}": "输入价格:${{price}} / 1M tokens{{audioPrice}}", + "Web搜索价格:${{price}} / 1K 次": "Web搜索价格:${{price}} / 1K 次", + "文件搜索价格:${{price}} / 1K 次": "文件搜索价格:${{price}} / 1K 次", + "仅供参考,以实际扣费为准": "仅供参考,以实际扣费为准", + "价格:${{price}} * {{ratioType}}:{{ratio}}": "价格:${{price}} * {{ratioType}}:{{ratio}}", + "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}": "模型: {{ratio}} * {{ratioType}}:{{groupRatio}}", + "提示价格:${{price}} / 1M tokens": "提示价格:${{price}} / 1M tokens", + "模型价格 ${{price}},{{ratioType}} {{ratio}}": "模型价格 ${{price}},{{ratioType}} {{ratio}}", + "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}": "模型: {{ratio}} * {{ratioType}}: {{groupRatio}}", + "不是合法的 JSON 字符串": "不是合法的 JSON 字符串", + "请求发生错误: ": "请求发生错误: ", + "解析响应数据时发生错误": "解析响应数据时发生错误", + "连接已断开": "连接已断开", + "建立连接时发生错误": "建立连接时发生错误", + "加载模型失败": "加载模型失败", + "加载分组失败": "加载分组失败", + "消息已复制到剪贴板": "消息已复制到剪贴板", + "确认删除": "确认删除", + "确定要删除这条消息吗?": "确定要删除这条消息吗?", + "已删除消息及其回复": "已删除消息及其回复", + "消息已删除": "消息已删除", + "消息已编辑": "消息已编辑", + "检测到该消息后有AI回复,是否删除后续回复并重新生成?": "检测到该消息后有AI回复,是否删除后续回复并重新生成?", + "重新生成": "重新生成", + "消息已更新": "消息已更新", + "加载关于内容失败...": "加载关于内容失败...", + "可在设置页面设置关于内容,支持 HTML & Markdown": "可在设置页面设置关于内容,支持 HTML & Markdown", + "New API项目仓库地址:": "New API项目仓库地址:", + "| 基于": "| 基于", + "本项目根据": "本项目根据", + "MIT许可证": "MIT许可证", + "授权,需在遵守": "授权,需在遵守", + "Apache-2.0协议": "Apache-2.0协议", + "管理员暂时未设置任何关于内容": "管理员暂时未设置任何关于内容", + "仅支持 OpenAI 接口格式": "仅支持 OpenAI 接口格式", + "请填写密钥": "请填写密钥", + "获取模型列表成功": "获取模型列表成功", + "获取模型列表失败": "获取模型列表失败", + "请填写渠道名称和渠道密钥!": "请填写渠道名称和渠道密钥!", + "请至少选择一个模型!": "请至少选择一个模型!", + "模型映射必须是合法的 JSON 格式!": "模型映射必须是合法的 JSON 格式!", + "提交失败,请勿重复提交!": "提交失败,请勿重复提交!", + "渠道创建成功!": "渠道创建成功!", + "已新增 {{count}} 个模型:{{list}}": "已新增 {{count}} 个模型:{{list}}", + "未发现新增模型": "未发现新增模型", + "新建": "新建", + "更新渠道信息": "更新渠道信息", + "创建新的渠道": "创建新的渠道", + "基本信息": "基本信息", + "渠道的基本配置信息": "渠道的基本配置信息", + "请选择渠道类型": "请选择渠道类型", + "请为渠道命名": "请为渠道命名", + "请输入密钥,一行一个": "请输入密钥,一行一个", + "批量创建": "批量创建", + "API 配置": "API 配置", + "API 地址和相关配置": "API 地址和相关配置", + "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"", + "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com": "请输入 AZURE_OPENAI_ENDPOINT,例如:https://docs-test-001.openai.azure.com", + "请输入默认 API 版本,例如:2025-04-01-preview": "请输入默认 API 版本,例如:2025-04-01-preview", + "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。": "如果你对接的是上游One API或者New API等转发项目,请使用OpenAI类型,不要使用此类型,除非你知道你在做什么。", + "完整的 Base URL,支持变量{model}": "完整的 Base URL,支持变量{model}", + "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions": "请输入完整的URL,例如:https://api.openai.com/v1/chat/completions", + "Dify渠道只适配chatflow和agent,并且agent不支持图片!": "Dify渠道只适配chatflow和agent,并且agent不支持图片!", + "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/": "此项可选,用于通过自定义API地址来进行 API 调用,末尾不要带/v1和/", + "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写": "对于官方渠道,new-api已经内置地址,除非是第三方代理站点或者Azure的特殊接入地址,否则不需要填写", + "私有部署地址": "私有部署地址", + "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi": "请输入私有部署地址,格式为:https://fastgpt.run/api/openapi", + "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用": "注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用", + "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com": "请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com", + "模型选择和映射设置": "模型选择和映射设置", + "模型": "模型", + "请选择该渠道所支持的模型": "请选择该渠道所支持的模型", + "填入相关模型": "填入相关模型", + "填入所有模型": "填入所有模型", + "获取模型列表": "获取模型列表", + "清除所有模型": "清除所有模型", + "输入自定义模型名称": "输入自定义模型名称", + "模型重定向": "模型重定向", + "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:", + "填入模板": "填入模板", + "默认测试模型": "默认测试模型", + "不填则为模型列表第一个": "不填则为模型列表第一个", + "渠道的高级配置选项": "渠道的高级配置选项", + "请选择可以使用该渠道的分组": "请选择可以使用该渠道的分组", + "请在系统设置页面编辑分组倍率以添加新的分组:": "请在系统设置页面编辑分组倍率以添加新的分组:", + "部署地区": "部署地区", + "知识库 ID": "知识库 ID", + "渠道标签": "渠道标签", + "渠道优先级": "渠道优先级", + "渠道权重": "渠道权重", + "渠道额外设置": "渠道额外设置", + "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:", + "参数覆盖": "参数覆盖", + "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:", + "请输入组织org-xxx": "请输入组织org-xxx", + "组织,可选,不填则为默认组织": "组织,可选,不填则为默认组织", + "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道", + "状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "状态码复写(仅影响本地判断,不修改返回到上游的状态码)", + "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:", + "编辑标签": "编辑标签", + "标签信息": "标签信息", + "标签的基本配置": "标签的基本配置", + "所有编辑均为覆盖操作,留空则不更改": "所有编辑均为覆盖操作,留空则不更改", + "标签名称": "标签名称", + "请输入新标签,留空则解散标签": "请输入新标签,留空则解散标签", + "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。": "当前模型列表为该标签下所有渠道模型列表最长的一个,并非所有渠道的并集,请注意可能导致某些渠道模型丢失。", + "请选择该渠道所支持的模型,留空则不更改": "请选择该渠道所支持的模型,留空则不更改", + "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改": "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改", + "清空重定向": "清空重定向", + "分组设置": "分组设置", + "用户分组配置": "用户分组配置", + "请选择可以使用该渠道的分组,留空则不更改": "请选择可以使用该渠道的分组,留空则不更改", + "正在跳转...": "正在跳转...", + "小时": "小时", + "周": "周", + "模型调用次数占比": "模型调用次数占比", + "模型消耗分布": "模型消耗分布", + "总计": "总计", + "早上好": "早上好", + "中午好": "中午好", + "下午好": "下午好", + "账户数据": "账户数据", + "使用统计": "使用统计", + "统计次数": "统计次数", + "资源消耗": "资源消耗", + "统计额度": "统计额度", + "性能指标": "性能指标", + "平均RPM": "平均RPM", + "复制成功": "复制成功", + "进行中": "进行中", + "异常": "异常", + "正常": "正常", + "可用率": "可用率", + "有异常": "有异常", + "高延迟": "高延迟", + "维护中": "维护中", + "暂无监控数据": "暂无监控数据", + "搜索条件": "搜索条件", + "时间粒度": "时间粒度", + "模型数据分析": "模型数据分析", + "消耗分布": "消耗分布", + "调用次数分布": "调用次数分布", + "API信息": "API信息", + "暂无API信息": "暂无API信息", + "请联系管理员在系统设置中配置API信息": "请联系管理员在系统设置中配置API信息", + "显示最新20条": "显示最新20条", + "请联系管理员在系统设置中配置公告信息": "请联系管理员在系统设置中配置公告信息", + "暂无常见问答": "暂无常见问答", + "请联系管理员在系统设置中配置常见问答": "请联系管理员在系统设置中配置常见问答", + "服务可用性": "服务可用性", + "请联系管理员在系统设置中配置Uptime": "请联系管理员在系统设置中配置Uptime", + "加载首页内容失败...": "加载首页内容失败...", + "统一的大模型接口网关": "统一的大模型接口网关", + "更好的价格,更好的稳定性,无需订阅": "更好的价格,更好的稳定性,无需订阅", + "开始使用": "开始使用", + "支持众多的大模型供应商": "支持众多的大模型供应商", + "页面未找到,请检查您的浏览器地址是否正确": "页面未找到,请检查您的浏览器地址是否正确", + "登录过期,请重新登录!": "登录过期,请重新登录!", + "兑换码更新成功!": "兑换码更新成功!", + "兑换码创建成功!": "兑换码创建成功!", + "兑换码创建成功": "兑换码创建成功", + "兑换码创建成功,是否下载兑换码?": "兑换码创建成功,是否下载兑换码?", + "兑换码将以文本文件的形式下载,文件名为兑换码的名称。": "兑换码将以文本文件的形式下载,文件名为兑换码的名称。", + "更新兑换码信息": "更新兑换码信息", + "创建新的兑换码": "创建新的兑换码", + "设置兑换码的基本信息": "设置兑换码的基本信息", + "请输入名称": "请输入名称", + "选择过期时间(可选,留空为永久)": "选择过期时间(可选,留空为永久)", + "额度设置": "额度设置", + "设置兑换码的额度和数量": "设置兑换码的额度和数量", + "请输入额度": "请输入额度", + "生成数量": "生成数量", + "请输入生成数量": "请输入生成数量", + "你似乎并没有修改什么": "你似乎并没有修改什么", + "部分保存失败,请重试": "部分保存失败,请重试", + "保存成功": "保存成功", + "保存失败,请重试": "保存失败,请重试", + "请检查输入": "请检查输入", + "聊天配置": "聊天配置", + "为一个 JSON 文本": "为一个 JSON 文本", + "保存聊天设置": "保存聊天设置", + "设置已保存": "设置已保存", + "API地址": "API地址", + "说明": "说明", + "颜色": "颜色", + "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)": "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)", + "批量删除": "批量删除", + "保存设置": "保存设置", + "添加API": "添加API", + "请输入API地址": "请输入API地址", + "如:香港线路": "如:香港线路", + "请输入线路描述": "请输入线路描述", + "如:大带宽批量分析图片推荐": "如:大带宽批量分析图片推荐", + "请输入说明": "请输入说明", + "标识颜色": "标识颜色", + "确定要删除此API信息吗?": "确定要删除此API信息吗?", + "警告": "警告", + "发布时间": "发布时间", + "操作": "操作", + "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)": "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)", + "添加公告": "添加公告", + "编辑公告": "编辑公告", + "公告内容": "公告内容", + "请输入公告内容": "请输入公告内容", + "请选择发布日期": "请选择发布日期", + "公告类型": "公告类型", + "说明信息": "说明信息", + "可选,公告的补充说明": "可选,公告的补充说明", + "确定要删除此公告吗?": "确定要删除此公告吗?", + "数据看板设置": "数据看板设置", + "启用数据看板(实验性)": "启用数据看板(实验性)", + "数据看板更新间隔": "数据看板更新间隔", + "设置过短会影响数据库性能": "设置过短会影响数据库性能", + "数据看板默认时间粒度": "数据看板默认时间粒度", + "仅修改展示粒度,统计精确到小时": "仅修改展示粒度,统计精确到小时", + "保存数据看板设置": "保存数据看板设置", + "问题标题": "问题标题", + "回答内容": "回答内容", + "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)": "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)", + "添加问答": "添加问答", + "编辑问答": "编辑问答", + "请输入问题标题": "请输入问题标题", + "请输入回答内容": "请输入回答内容", + "确定要删除此问答吗?": "确定要删除此问答吗?", + "分类名称": "分类名称", + "Uptime Kuma地址": "Uptime Kuma地址", + "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)": "Uptime Kuma监控分类管理,可以配置多个监控分类用于服务状态展示(最多20个)", + "编辑分类": "编辑分类", + "添加分类": "添加分类", + "请输入分类名称,如:OpenAI、Claude等": "请输入分类名称,如:OpenAI、Claude等", + "请输入分类名称": "请输入分类名称", + "请输入Uptime Kuma服务地址,如:https://status.example.com": "请输入Uptime Kuma服务地址,如:https://status.example.com", + "请输入Uptime Kuma地址": "请输入Uptime Kuma地址", + "请输入状态页面的Slug,如:my-status": "请输入状态页面的Slug,如:my-status", + "请输入状态页面Slug": "请输入状态页面Slug", + "确定要删除此分类吗?": "确定要删除此分类吗?", + "绘图设置": "绘图设置", + "启用绘图功能": "启用绘图功能", + "允许回调(会泄露服务器 IP 地址)": "允许回调(会泄露服务器 IP 地址)", + "允许 AccountFilter 参数": "允许 AccountFilter 参数", + "开启之后会清除用户提示词中的": "开启之后会清除用户提示词中的", + "以及": "以及", + "检测必须等待绘图成功才能进行放大等操作": "检测必须等待绘图成功才能进行放大等操作", + "保存绘图设置": "保存绘图设置", + "Claude设置": "Claude设置", + "Claude请求头覆盖": "Claude请求头覆盖", + "为一个 JSON 文本,例如:": "为一个 JSON 文本,例如:", + "缺省 MaxTokens": "缺省 MaxTokens", + "启用Claude思考适配(-thinking后缀)": "启用Claude思考适配(-thinking后缀)", + "思考适配 BudgetTokens 百分比": "思考适配 BudgetTokens 百分比", + "0.1-1之间的小数": "0.1-1之间的小数", + "Gemini设置": "Gemini设置", + "Gemini安全设置": "Gemini安全设置", + "default为默认设置,可单独设置每个模型的版本": "default为默认设置,可单独设置每个模型的版本", + "例如:": "例如:", + "Gemini思考适配设置": "Gemini思考适配设置", + "启用Gemini思考后缀适配": "启用Gemini思考后缀适配", + "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀": "适配 -thinking、-thinking-预算数字 和 -nothinking 后缀", + "0.002-1之间的小数": "0.002-1之间的小数", + "全局设置": "全局设置", + "启用请求透传": "启用请求透传", + "连接保活设置": "连接保活设置", + "启用Ping间隔": "启用Ping间隔", + "Ping间隔(秒)": "Ping间隔(秒)", + "新用户初始额度": "新用户初始额度", + "请求预扣费额度": "请求预扣费额度", + "请求结束后多退少补": "请求结束后多退少补", + "邀请新用户奖励额度": "邀请新用户奖励额度", + "新用户使用邀请码奖励额度": "新用户使用邀请码奖励额度", + "例如:1000": "例如:1000", + "保存额度设置": "保存额度设置", + "例如发卡网站的购买链接": "例如发卡网站的购买链接", + "文档地址": "文档地址", + "单位美元额度": "单位美元额度", + "一单位货币能兑换的额度": "一单位货币能兑换的额度", + "失败重试次数": "失败重试次数", + "以货币形式显示额度": "以货币形式显示额度", + "额度查询接口返回令牌额度而非用户额度": "额度查询接口返回令牌额度而非用户额度", + "默认折叠侧边栏": "默认折叠侧边栏", + "开启后不限制:必须设置模型倍率": "开启后不限制:必须设置模型倍率", + "保存通用设置": "保存通用设置", + "请选择日志记录时间": "请选择日志记录时间", + "条日志已清理!": "条日志已清理!", + "日志清理失败:": "日志清理失败:", + "启用额度消费日志记录": "启用额度消费日志记录", + "日志记录时间": "日志记录时间", + "清除历史日志": "清除历史日志", + "保存日志设置": "保存日志设置", + "监控设置": "监控设置", + "测试所有渠道的最长响应时间": "测试所有渠道的最长响应时间", + "额度提醒阈值": "额度提醒阈值", + "低于此额度时将发送邮件提醒用户": "低于此额度时将发送邮件提醒用户", + "失败时自动禁用通道": "失败时自动禁用通道", + "成功时自动启用通道": "成功时自动启用通道", + "自动禁用关键词": "自动禁用关键词", + "一行一个,不区分大小写": "一行一个,不区分大小写", + "屏蔽词过滤设置": "屏蔽词过滤设置", + "启用屏蔽词过滤功能": "启用屏蔽词过滤功能", + "启用 Prompt 检查": "启用 Prompt 检查", + "一行一个屏蔽词,不需要符号分割": "一行一个屏蔽词,不需要符号分割", + "保存屏蔽词过滤设置": "保存屏蔽词过滤设置", + "更新成功": "更新成功", + "更新失败": "更新失败", + "服务器地址": "服务器地址", + "更新服务器地址": "更新服务器地址", + "请先填写服务器地址": "请先填写服务器地址", + "充值分组倍率不是合法的 JSON 字符串": "充值分组倍率不是合法的 JSON 字符串", + "充值方式设置不是合法的 JSON 字符串": "充值方式设置不是合法的 JSON 字符串", + "支付设置": "支付设置", + "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)": "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)", + "例如:https://yourdomain.com": "例如:https://yourdomain.com", + "易支付商户ID": "易支付商户ID", + "易支付商户密钥": "易支付商户密钥", + "敏感信息不会发送到前端显示": "敏感信息不会发送到前端显示", + "回调地址": "回调地址", + "充值价格(x元/美金)": "充值价格(x元/美金)", + "例如:7,就是7元/美金": "例如:7,就是7元/美金", + "最低充值美元数量": "最低充值美元数量", + "例如:2,就是最低充值2$": "例如:2,就是最低充值2$", + "为一个 JSON 文本,键为组名称,值为倍率": "为一个 JSON 文本,键为组名称,值为倍率", + "充值方式设置": "充值方式设置", + "更新支付设置": "更新支付设置", + "模型请求速率限制": "模型请求速率限制", + "启用用户模型请求速率限制(可能会影响高并发性能)": "启用用户模型请求速率限制(可能会影响高并发性能)", + "分钟": "分钟", + "频率限制的周期(分钟)": "频率限制的周期(分钟)", + "用户每周期最多请求次数": "用户每周期最多请求次数", + "包括失败请求的次数,0代表不限制": "包括失败请求的次数,0代表不限制", + "用户每周期最多请求完成次数": "用户每周期最多请求完成次数", + "只包括请求成功的次数": "只包括请求成功的次数", + "分组速率限制": "分组速率限制", + "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}", + "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。": "示例:{\"default\": [200, 100], \"vip\": [0, 1000]}。", + "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。": "[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。", + "分组速率配置优先级高于全局速率限制。": "分组速率配置优先级高于全局速率限制。", + "限制周期统一使用上方配置的“限制周期”值。": "限制周期统一使用上方配置的“限制周期”值。", + "保存模型速率限制": "保存模型速率限制", + "保存失败": "保存失败", + "为一个 JSON 文本,键为分组名称,值为倍率": "为一个 JSON 文本,键为分组名称,值为倍率", + "用户可选分组": "用户可选分组", + "为一个 JSON 文本,键为分组名称,值为分组描述": "为一个 JSON 文本,键为分组名称,值为分组描述", + "自动分组auto,从第一个开始选择": "自动分组auto,从第一个开始选择", + "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]": "必须是有效的 JSON 字符串数组,例如:[\"g1\",\"g2\"]", + "模型固定价格": "模型固定价格", + "一次调用消耗多少刀,优先级大于模型倍率": "一次调用消耗多少刀,优先级大于模型倍率", + "为一个 JSON 文本,键为模型名称,值为倍率": "为一个 JSON 文本,键为模型名称,值为倍率", + "模型补全倍率(仅对自定义模型有效)": "模型补全倍率(仅对自定义模型有效)", + "仅对自定义模型有效": "仅对自定义模型有效", + "保存模型倍率设置": "保存模型倍率设置", + "确定重置模型倍率吗?": "确定重置模型倍率吗?", + "重置模型倍率": "重置模型倍率", + "获取启用模型失败:": "获取启用模型失败:", + "获取启用模型失败": "获取启用模型失败", + "JSON解析错误:": "JSON解析错误:", + "保存失败:": "保存失败:", + "输入模型倍率": "输入模型倍率", + "输入补全倍率": "输入补全倍率", + "请输入数字": "请输入数字", + "模型名称已存在": "模型名称已存在", + "请先选择需要批量设置的模型": "请先选择需要批量设置的模型", + "请输入模型倍率和补全倍率": "请输入模型倍率和补全倍率", + "请输入有效的数字": "请输入有效的数字", + "请输入填充值": "请输入填充值", + "批量设置成功": "批量设置成功", + "已为 {{count}} 个模型设置{{type}}": "已为 {{count}} 个模型设置{{type}}", + "模型倍率和补全倍率": "模型倍率和补全倍率", + "添加模型": "添加模型", + "批量设置": "批量设置", + "应用更改": "应用更改", + "搜索模型名称": "搜索模型名称", + "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除": "此页面仅显示未设置价格或倍率的模型,设置后将自动从列表中移除", + "定价模式": "定价模式", + "固定价格": "固定价格", + "固定价格(每次)": "固定价格(每次)", + "输入每次价格": "输入每次价格", + "输入补全价格": "输入补全价格", + "批量设置模型参数": "批量设置模型参数", + "设置类型": "设置类型", + "模型倍率和补全倍率同时设置": "模型倍率和补全倍率同时设置", + "模型倍率值": "模型倍率值", + "请输入模型倍率": "请输入模型倍率", + "补全倍率值": "补全倍率值", + "请输入补全倍率": "请输入补全倍率", + "请输入数值": "请输入数值", + "将为选中的 ": "将为选中的 ", + " 个模型设置相同的值": " 个模型设置相同的值", + "当前设置类型: ": "当前设置类型: ", + "默认补全倍率": "默认补全倍率", + "添加成功": "添加成功", + "价格设置方式": "价格设置方式", + "按倍率设置": "按倍率设置", + "按价格设置": "按价格设置", + "输入价格": "输入价格", + "输出价格": "输出价格", + "获取渠道失败:": "获取渠道失败:", + "请至少选择一个渠道": "请至少选择一个渠道", + "后端请求失败": "后端请求失败", + "部分渠道测试失败:": "部分渠道测试失败:", + "未找到差异化倍率,无需同步": "未找到差异化倍率,无需同步", + "请求后端接口失败:": "请求后端接口失败:", + "同步成功": "同步成功", + "部分保存失败": "部分保存失败", + "未找到匹配的模型": "未找到匹配的模型", + "暂无差异化倍率显示": "暂无差异化倍率显示", + "请先选择同步渠道": "请先选择同步渠道", + "倍率类型": "倍率类型", + "缓存倍率": "缓存倍率", + "当前值": "当前值", + "未设置": "未设置", + "与本地相同": "与本地相同", + "运营设置": "运营设置", + "聊天设置": "聊天设置", + "速率限制设置": "速率限制设置", + "模型相关设置": "模型相关设置", + "系统设置": "系统设置", + "仪表盘设置": "仪表盘设置", + "获取初始化状态失败": "获取初始化状态失败", + "表单引用错误,请刷新页面重试": "表单引用错误,请刷新页面重试", + "请输入管理员用户名": "请输入管理员用户名", + "密码长度至少为8个字符": "密码长度至少为8个字符", + "两次输入的密码不一致": "两次输入的密码不一致", + "系统初始化成功,正在跳转...": "系统初始化成功,正在跳转...", + "初始化失败,请重试": "初始化失败,请重试", + "系统初始化失败,请重试": "系统初始化失败,请重试", + "系统初始化": "系统初始化", + "欢迎使用,请完成以下设置以开始使用系统": "欢迎使用,请完成以下设置以开始使用系统", + "数据库信息": "数据库信息", + "管理员账号": "管理员账号", + "设置系统管理员的登录信息": "设置系统管理员的登录信息", + "管理员账号已经初始化过,请继续设置其他参数": "管理员账号已经初始化过,请继续设置其他参数", + "密码": "密码", + "请输入管理员密码": "请输入管理员密码", + "请确认管理员密码": "请确认管理员密码", + "选择适合您使用场景的模式": "选择适合您使用场景的模式", + "对外运营模式": "对外运营模式", + "适用于为多个用户提供服务的场景": "适用于为多个用户提供服务的场景", + "默认模式": "默认模式", + "适用于个人使用的场景,不需要设置模型价格": "适用于个人使用的场景,不需要设置模型价格", + "无需计费": "无需计费", + "演示站点模式": "演示站点模式", + "适用于展示系统功能的场景,提供基础功能演示": "适用于展示系统功能的场景,提供基础功能演示", + "初始化系统": "初始化系统", + "使用模式说明": "使用模式说明", + "我已了解": "我已了解", + "默认模式,适用于为多个用户提供服务的场景。": "默认模式,适用于为多个用户提供服务的场景。", + "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。": "此模式下,系统将计算每次调用的用量,您需要对每个模型都设置价格,如果没有设置价格,用户将无法使用该模型。", + "多用户支持": "多用户支持", + "适用于个人使用的场景。": "适用于个人使用的场景。", + "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。": "不需要设置模型价格,系统将弱化用量计算,您可专注于使用模型。", + "个人使用": "个人使用", + "适用于展示系统功能的场景。": "适用于展示系统功能的场景。", + "提供基础功能演示,方便用户了解系统特性。": "提供基础功能演示,方便用户了解系统特性。", + "体验试用": "体验试用", + "自动选择": "自动选择", + "过期时间格式错误!": "过期时间格式错误!", + "令牌更新成功!": "令牌更新成功!", + "令牌创建成功,请在列表页面点击复制获取令牌!": "令牌创建成功,请在列表页面点击复制获取令牌!", + "更新令牌信息": "更新令牌信息", + "创建新的令牌": "创建新的令牌", + "设置令牌的基本信息": "设置令牌的基本信息", + "请选择过期时间": "请选择过期时间", + "一天": "一天", + "一个月": "一个月", + "设置令牌可用额度和数量": "设置令牌可用额度和数量", + "新建数量": "新建数量", + "请选择或输入创建令牌的数量": "请选择或输入创建令牌的数量", + "20个": "20个", + "100个": "100个", + "取消无限额度": "取消无限额度", + "设为无限额度": "设为无限额度", + "设置令牌的访问限制": "设置令牌的访问限制", + "IP白名单": "IP白名单", + "允许的IP,一行一个,不填写则不限制": "允许的IP,一行一个,不填写则不限制", + "请勿过度信任此功能,IP可能被伪造": "请勿过度信任此功能,IP可能被伪造", + "勾选启用模型限制后可选择": "勾选启用模型限制后可选择", + "非必要,不建议启用模型限制": "非必要,不建议启用模型限制", + "分组信息": "分组信息", + "设置令牌的分组": "设置令牌的分组", + "令牌分组,默认为用户的分组": "令牌分组,默认为用户的分组", + "管理员未设置用户可选分组": "管理员未设置用户可选分组", + "请输入兑换码!": "请输入兑换码!", + "兑换成功!": "兑换成功!", + "成功兑换额度:": "成功兑换额度:", + "请求失败": "请求失败", + "超级管理员未设置充值链接!": "超级管理员未设置充值链接!", + "管理员未开启在线充值!": "管理员未开启在线充值!", + "充值数量不能小于": "充值数量不能小于", + "支付请求失败": "支付请求失败", + "划转金额最低为": "划转金额最低为", + "邀请链接已复制到剪切板": "邀请链接已复制到剪切板", + "支付方式配置错误, 请联系管理员": "支付方式配置错误, 请联系管理员", + "划转邀请额度": "划转邀请额度", + "可用邀请额度": "可用邀请额度", + "划转额度": "划转额度", + "充值确认": "充值确认", + "充值数量": "充值数量", + "实付金额": "实付金额", + "支付方式": "支付方式", + "在线充值": "在线充值", + "快速方便的充值方式": "快速方便的充值方式", + "选择充值额度": "选择充值额度", + "实付": "实付", + "或输入自定义金额": "或输入自定义金额", + "充值数量,最低 ": "充值数量,最低 ", + "选择支付方式": "选择支付方式", + "处理中": "处理中", + "兑换码充值": "兑换码充值", + "使用兑换码快速充值": "使用兑换码快速充值", + "请输入兑换码": "请输入兑换码", + "兑换中...": "兑换中...", + "兑换": "兑换", + "邀请奖励": "邀请奖励", + "邀请好友获得额外奖励": "邀请好友获得额外奖励", + "待使用收益": "待使用收益", + "总收益": "总收益", + "邀请人数": "邀请人数", + "邀请链接": "邀请链接", + "邀请好友注册,好友充值后您可获得相应奖励": "邀请好友注册,好友充值后您可获得相应奖励", + "通过划转功能将奖励额度转入到您的账户余额中": "通过划转功能将奖励额度转入到您的账户余额中", + "邀请的好友越多,获得的奖励越多": "邀请的好友越多,获得的奖励越多", + "用户名和密码不能为空!": "用户名和密码不能为空!", + "用户账户创建成功!": "用户账户创建成功!", + "提交": "提交", + "创建新用户账户": "创建新用户账户", + "请输入显示名称": "请输入显示名称", + "请输入密码": "请输入密码", + "请输入备注(仅管理员可见)": "请输入备注(仅管理员可见)", + "编辑用户": "编辑用户", + "用户的基本账户信息": "用户的基本账户信息", + "请输入新的用户名": "请输入新的用户名", + "请输入新的密码,最短 8 位": "请输入新的密码,最短 8 位", + "显示名称": "显示名称", + "请输入新的显示名称": "请输入新的显示名称", + "权限设置": "权限设置", + "用户分组和额度管理": "用户分组和额度管理", + "请输入新的剩余额度": "请输入新的剩余额度", + "添加额度": "添加额度", + "第三方账户绑定状态(只读)": "第三方账户绑定状态(只读)", + "已绑定的 GitHub 账户": "已绑定的 GitHub 账户", + "已绑定的 OIDC 账户": "已绑定的 OIDC 账户", + "已绑定的微信账户": "已绑定的微信账户", + "已绑定的邮箱账户": "已绑定的邮箱账户", + "已绑定的 Telegram 账户": "已绑定的 Telegram 账户", + "新额度": "新额度", + "需要添加的额度(支持负数)": "需要添加的额度(支持负数)" +} \ No newline at end of file diff --git a/main.go b/main.go index c286650f..727d5db6 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ import ( "one-api/model" "one-api/router" "one-api/service" - "one-api/setting/operation_setting" + "one-api/setting/ratio_setting" "os" "strconv" @@ -32,12 +32,12 @@ var buildFS embed.FS var indexPage []byte func main() { - err := godotenv.Load(".env") - if err != nil { - common.SysLog("Support for .env file is disabled: " + err.Error()) - } - common.LoadEnv() + err := InitResources() + if err != nil { + common.FatalLog("failed to initialize resources: " + err.Error()) + return + } common.SetupLogger() common.SysLog("New API " + common.Version + " started") @@ -47,19 +47,7 @@ func main() { if common.DebugEnabled { common.SysLog("running in debug mode") } - // Initialize SQL Database - err = model.InitDB() - if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) - } - model.CheckSetup() - - // Initialize SQL Database - err = model.InitLogDB() - if err != nil { - common.FatalLog("failed to initialize database: " + err.Error()) - } defer func() { err := model.CloseDB() if err != nil { @@ -67,21 +55,6 @@ func main() { } }() - // Initialize Redis - err = common.InitRedisClient() - if err != nil { - common.FatalLog("failed to initialize Redis: " + err.Error()) - } - - // Initialize model settings - operation_setting.InitRatioSettings() - // Initialize constants - constant.InitEnv() - // Initialize options - model.InitOptionMap() - - service.InitTokenEncoders() - if common.RedisEnabled { // for compatibility with old versions common.MemoryCacheEnabled = true @@ -105,10 +78,12 @@ func main() { model.InitChannelCache() }() - go model.SyncOptions(common.SyncFrequency) go model.SyncChannelCache(common.SyncFrequency) } + // 热更新配置 + go model.SyncOptions(common.SyncFrequency) + // 数据看板 go model.UpdateQuotaData() @@ -184,3 +159,51 @@ func main() { common.FatalLog("failed to start HTTP server: " + err.Error()) } } + +func InitResources() error { + // Initialize resources here if needed + // This is a placeholder function for future resource initialization + err := godotenv.Load(".env") + if err != nil { + common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") + common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") + } + + // 加载环境变量 + common.InitEnv() + + // Initialize model settings + ratio_setting.InitRatioSettings() + + service.InitHttpClient() + + service.InitTokenEncoders() + + // Initialize SQL Database + err = model.InitDB() + if err != nil { + common.FatalLog("failed to initialize database: " + err.Error()) + return err + } + + model.CheckSetup() + + // Initialize options, should after model.InitDB() + model.InitOptionMap() + + // 初始化模型 + model.GetPricing() + + // Initialize SQL Database + err = model.InitLogDB() + if err != nil { + return err + } + + // Initialize Redis + err = common.InitRedisClient() + if err != nil { + return err + } + return nil +} diff --git a/middleware/auth.go b/middleware/auth.go index f387029f..ecf4844b 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -184,7 +184,7 @@ func TokenAuth() func(c *gin.Context) { } } // gemini api 从query中获取key - if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { + if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { skKey := c.Query("key") if skKey != "" { c.Request.Header.Set("Authorization", "Bearer "+skKey) diff --git a/middleware/distributor.go b/middleware/distributor.go index 1bfe1821..17916e7a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -11,6 +11,7 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "one-api/setting" + "one-api/setting/ratio_setting" "strconv" "strings" "time" @@ -24,7 +25,7 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { - allowIpsMap := c.GetStringMap("allow_ips") + allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps) if len(allowIpsMap) != 0 { clientIp := c.ClientIP() if _, ok := allowIpsMap[clientIp]; !ok { @@ -33,14 +34,14 @@ func Distribute() func(c *gin.Context) { } } var channel *model.Channel - channelId, ok := c.Get("specific_channel_id") + channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId) modelRequest, shouldSelectChannel, err := getModelRequest(c) if err != nil { abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) return } - userGroup := c.GetString(constant.ContextKeyUserGroup) - tokenGroup := c.GetString("token_group") + userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup) + tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) if tokenGroup != "" { // check common.UserUsableGroups[userGroup] if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok { @@ -48,13 +49,15 @@ func Distribute() func(c *gin.Context) { return } // check group in common.GroupRatio - if !setting.ContainsGroupRatio(tokenGroup) { - abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) - return + if !ratio_setting.ContainsGroupRatio(tokenGroup) { + if tokenGroup != "auto" { + abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup)) + return + } } userGroup = tokenGroup } - c.Set("group", userGroup) + common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup) if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { @@ -73,9 +76,9 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user // check token model mapping - modelLimitEnable := c.GetBool("token_model_limit_enabled") + modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled) if modelLimitEnable { - s, ok := c.Get("token_model_limit") + s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit) var tokenModelLimit map[string]bool if ok { tokenModelLimit = s.(map[string]bool) @@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) { } if shouldSelectChannel { - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0) + var selectGroup string + channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + showGroup := userGroup + if userGroup == "auto" { + showGroup = fmt.Sprintf("auto(%s)", selectGroup) + } + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model) // 如果错误,但是渠道不为空,说明是数据库一致性问题 if channel != nil { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) @@ -113,7 +121,7 @@ func Distribute() func(c *gin.Context) { } } } - c.Set(constant.ContextKeyRequestStartTime, time.Now()) + common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) SetupContextForSelectedChannel(c, channel, modelRequest.Model) c.Next() } @@ -162,7 +170,26 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) - } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { + } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { + err = common.UnmarshalBodyReusable(c, &modelRequest) + var platform string + var relayMode int + if strings.HasPrefix(modelRequest.Model, "jimeng") { + platform = string(constant.TaskPlatformJimeng) + relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeJimengFetchByID { + shouldSelectChannel = false + } + } else { + platform = string(constant.TaskPlatformKling) + relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeKlingFetchByID { + shouldSelectChannel = false + } + } + c.Set("platform", platform) + c.Set("relay_mode", relayMode) + } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") { // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent relayMode := relayconstant.RelayModeGemini modelName := extractModelNameFromGeminiPath(c.Request.URL.Path) @@ -234,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("base_url", channel.GetBaseURL()) // TODO: api_version统一 switch channel.Type { - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: c.Set("api_version", channel.Other) - case common.ChannelTypeVertexAi: + case constant.ChannelTypeVertexAi: c.Set("region", channel.Other) - case common.ChannelTypeXunfei: + case constant.ChannelTypeXunfei: c.Set("api_version", channel.Other) - case common.ChannelTypeGemini: + case constant.ChannelTypeGemini: c.Set("api_version", channel.Other) - case common.ChannelTypeAli: + case constant.ChannelTypeAli: c.Set("plugin", channel.Other) - case common.ChannelCloudflare: + case constant.ChannelCloudflare: c.Set("api_version", channel.Other) - case common.ChannelTypeMokaAI: + case constant.ChannelTypeMokaAI: c.Set("api_version", channel.Other) - case common.ChannelTypeCoze: + case constant.ChannelTypeCoze: c.Set("bot_id", channel.Other) } } diff --git a/middleware/kling_adapter.go b/middleware/kling_adapter.go new file mode 100644 index 00000000..8e2a3551 --- /dev/null +++ b/middleware/kling_adapter.go @@ -0,0 +1,47 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "io" + "one-api/common" + "one-api/constant" + + "github.com/gin-gonic/gin" +) + +func KlingRequestConvert() func(c *gin.Context) { + return func(c *gin.Context) { + var originalReq map[string]interface{} + if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil { + c.Next() + return + } + + model, _ := originalReq["model"].(string) + prompt, _ := originalReq["prompt"].(string) + + unifiedReq := map[string]interface{}{ + "model": model, + "prompt": prompt, + "metadata": originalReq, + } + + jsonData, err := json.Marshal(unifiedReq) + if err != nil { + c.Next() + return + } + + // Rewrite request body and path + c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData)) + c.Request.URL.Path = "/v1/video/generations" + if image := originalReq["image"]; image == "" { + c.Set("action", constant.TaskActionTextGenerate) + } + + // We have to reset the request body for the next handlers + c.Set(common.KeyRequestBody, jsonData) + c.Next() + } +} diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index 34caa59b..14d9a737 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) { successMaxCount := setting.ModelRequestRateLimitSuccessCount // 获取分组 - group := c.GetString("token_group") + group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) if group == "" { - group = c.GetString(constant.ContextKeyUserGroup) + group = common.GetContextKeyString(c, constant.ContextKeyUserGroup) } //获取分组的限流配置 diff --git a/model/ability.go b/model/ability.go index 96a9ef6a..fb5301fe 100644 --- a/model/ability.go +++ b/model/ability.go @@ -21,7 +21,22 @@ type Ability struct { Tag *string `json:"tag" gorm:"index"` } -func GetGroupModels(group string) []string { +type AbilityWithChannel struct { + Ability + ChannelType int `json:"channel_type"` +} + +func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) { + var abilities []AbilityWithChannel + err := DB.Table("abilities"). + Select("abilities.*, channels.type as channel_type"). + Joins("left join channels on abilities.channel_id = channels.id"). + Where("abilities.enabled = ?", true). + Scan(&abilities).Error + return abilities, err +} + +func GetGroupEnabledModels(group string) []string { var models []string // Find distinct models DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models) @@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) { var priorities []int err := DB.Model(&Ability{}). Select("DISTINCT(priority)"). - Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal). + Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true). Order("priority DESC"). // 按优先级降序排序 Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中 @@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) { } func getChannelQuery(group string, model string, retry int) *gorm.DB { - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal) - channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true) + channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery) if retry != 0 { priority, err := getPriority(group, model, retry) if err != nil { common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error())) } else { - channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority) + channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority) } } diff --git a/model/cache.go b/model/cache.go index e2f83e22..3e5eb4c4 100644 --- a/model/cache.go +++ b/model/cache.go @@ -5,10 +5,13 @@ import ( "fmt" "math/rand" "one-api/common" + "one-api/setting" "sort" "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) var group2model2channels map[string]map[string][]*Channel @@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) { + var channel *Channel + var err error + selectGroup := group + if group == "auto" { + if len(setting.AutoGroups) == 0 { + return nil, selectGroup, errors.New("auto groups is not enabled") + } + for _, autoGroup := range setting.AutoGroups { + if common.DebugEnabled { + println("autoGroup:", autoGroup) + } + channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry) + if channel == nil { + continue + } else { + c.Set("auto_group", autoGroup) + selectGroup = autoGroup + if common.DebugEnabled { + println("selectGroup:", selectGroup) + } + break + } + } + } else { + channel, err = getRandomSatisfiedChannel(group, model, retry) + if err != nil { + return nil, group, err + } + } + if channel == nil { + return nil, group, errors.New("channel not found") + } + return channel, selectGroup, nil +} + +func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { if strings.HasPrefix(model, "gpt-4-gizmo") { model = "gpt-4-gizmo-*" } diff --git a/model/channel.go b/model/channel.go index 5f460ec8..ed9a478a 100644 --- a/model/channel.go +++ b/model/channel.go @@ -617,3 +617,39 @@ func CountAllTags() (int64, error) { err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error return total, err } + +// Get channels of specified type with pagination +func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) { + var channels []*Channel + order := "priority desc" + if idSort { + order = "id desc" + } + err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error + return channels, err +} + +// Count channels of specific type +func CountChannelsByType(channelType int) (int64, error) { + var count int64 + err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error + return count, err +} + +// Return map[type]count for all channels +func CountChannelsGroupByType() (map[int64]int64, error) { + type result struct { + Type int64 `gorm:"column:type"` + Count int64 `gorm:"column:count"` + } + var results []result + err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error + if err != nil { + return nil, err + } + counts := make(map[int64]int64) + for _, r := range results { + counts[r.Type] = r.Count + } + return counts, nil +} diff --git a/model/main.go b/model/main.go index a5efda62..e2f9aecb 100644 --- a/model/main.go +++ b/model/main.go @@ -46,6 +46,15 @@ func initCol() { logGroupCol = commonGroupCol logKeyCol = commonKeyCol } + } else { + // LOG_SQL_DSN 为空时,日志数据库与主数据库相同 + if common.UsingPostgreSQL { + logGroupCol = `"group"` + logKeyCol = `"key"` + } else { + logGroupCol = commonGroupCol + logKeyCol = commonKeyCol + } } // log sql type and database type //common.SysLog("Using Log SQL Type: " + common.LogSqlType) diff --git a/model/midjourney.go b/model/midjourney.go index e8140447..c6ef5de5 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -14,6 +14,8 @@ type Midjourney struct { StartTime int64 `json:"start_time" gorm:"index"` FinishTime int64 `json:"finish_time" gorm:"index"` ImageUrl string `json:"image_url"` + VideoUrl string `json:"video_url"` + VideoUrls string `json:"video_urls"` Status string `json:"status" gorm:"type:varchar(20);index"` Progress string `json:"progress" gorm:"type:varchar(30);index"` FailReason string `json:"fail_reason"` diff --git a/model/option.go b/model/option.go index d1689cb7..ea72e5ee 100644 --- a/model/option.go +++ b/model/option.go @@ -5,6 +5,7 @@ import ( "one-api/setting" "one-api/setting/config" "one-api/setting/operation_setting" + "one-api/setting/ratio_setting" "strconv" "strings" "time" @@ -76,6 +77,9 @@ func InitOptionMap() { common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp) common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["Chats"] = setting.Chats2JsonString() + common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString() + common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup) + common.OptionMap["PayMethods"] = setting.PayMethods2JsonString() common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["TelegramBotToken"] = "" @@ -94,13 +98,13 @@ func InitOptionMap() { common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString() - common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() - common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() - common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() - common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString() - common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString() + common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString() + common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString() + common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString() + common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString() + common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString() common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString() - common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString() + common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink //common.OptionMap["ChatLink"] = common.ChatLink //common.OptionMap["ChatLink2"] = common.ChatLink2 @@ -123,6 +127,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled()) // 自动添加所有注册的模型配置 modelConfigs := config.GlobalConfig.ExportAllConfigs() @@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) { common.ImageDownloadPermission = intValue } } - if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": @@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) { common.SMTPSSLEnabled = boolValue case "WorkerAllowHttpImageRequestEnabled": setting.WorkerAllowHttpImageRequestEnabled = boolValue + case "DefaultUseAutoGroup": + setting.DefaultUseAutoGroup = boolValue + case "ExposeRatioEnabled": + ratio_setting.SetExposeRatioEnabled(boolValue) } } switch key { @@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) { setting.PayAddress = value case "Chats": err = setting.UpdateChatsByJsonString(value) + case "AutoGroups": + err = setting.UpdateAutoGroupsByJsonString(value) case "CustomCallbackAddress": setting.CustomCallbackAddress = value case "EpayId": @@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) { case "DataExportDefaultTime": common.DataExportDefaultTime = value case "ModelRatio": - err = operation_setting.UpdateModelRatioByJSONString(value) + err = ratio_setting.UpdateModelRatioByJSONString(value) case "GroupRatio": - err = setting.UpdateGroupRatioByJSONString(value) + err = ratio_setting.UpdateGroupRatioByJSONString(value) case "GroupGroupRatio": - err = setting.UpdateGroupGroupRatioByJSONString(value) + err = ratio_setting.UpdateGroupGroupRatioByJSONString(value) case "UserUsableGroups": err = setting.UpdateUserUsableGroupsByJSONString(value) case "CompletionRatio": - err = operation_setting.UpdateCompletionRatioByJSONString(value) + err = ratio_setting.UpdateCompletionRatioByJSONString(value) case "ModelPrice": - err = operation_setting.UpdateModelPriceByJSONString(value) + err = ratio_setting.UpdateModelPriceByJSONString(value) case "CacheRatio": - err = operation_setting.UpdateCacheRatioByJSONString(value) + err = ratio_setting.UpdateCacheRatioByJSONString(value) case "TopUpLink": common.TopUpLink = value //case "ChatLink": @@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) { operation_setting.AutomaticDisableKeywordsFromString(value) case "StreamCacheQueueLength": setting.StreamCacheQueueLength, _ = strconv.Atoi(value) + case "PayMethods": + err = setting.UpdatePayMethodsByJsonString(value) } return err } diff --git a/model/pricing.go b/model/pricing.go index ba1815e2..0c0216f1 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -1,20 +1,24 @@ package model import ( + "fmt" "one-api/common" - "one-api/setting/operation_setting" + "one-api/constant" + "one-api/setting/ratio_setting" + "one-api/types" "sync" "time" ) type Pricing struct { - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - OwnerBy string `json:"owner_by"` - CompletionRatio float64 `json:"completion_ratio"` - EnableGroup []string `json:"enable_groups,omitempty"` + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + OwnerBy string `json:"owner_by"` + CompletionRatio float64 `json:"completion_ratio"` + EnableGroup []string `json:"enable_groups"` + SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` } var ( @@ -23,56 +27,98 @@ var ( updatePricingLock sync.Mutex ) -func GetPricing() []Pricing { - updatePricingLock.Lock() - defer updatePricingLock.Unlock() +var ( + modelSupportEndpointTypes = make(map[string][]constant.EndpointType) + modelSupportEndpointsLock = sync.RWMutex{} +) +func GetPricing() []Pricing { if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { - updatePricing() + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + // Double check after acquiring the lock + if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { + modelSupportEndpointsLock.Lock() + defer modelSupportEndpointsLock.Unlock() + updatePricing() + } } - //if group != "" { - // userPricingMap := make([]Pricing, 0) - // models := GetGroupModels(group) - // for _, pricing := range pricingMap { - // if !common.StringsContains(models, pricing.ModelName) { - // pricing.Available = false - // } - // userPricingMap = append(userPricingMap, pricing) - // } - // return userPricingMap - //} return pricingMap } +func GetModelSupportEndpointTypes(model string) []constant.EndpointType { + if model == "" { + return make([]constant.EndpointType, 0) + } + modelSupportEndpointsLock.RLock() + defer modelSupportEndpointsLock.RUnlock() + if endpoints, ok := modelSupportEndpointTypes[model]; ok { + return endpoints + } + return make([]constant.EndpointType, 0) +} + func updatePricing() { //modelRatios := common.GetModelRatios() - enableAbilities := GetAllEnableAbilities() - modelGroupsMap := make(map[string][]string) + enableAbilities, err := GetAllEnableAbilityWithChannels() + if err != nil { + common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + return + } + modelGroupsMap := make(map[string]*types.Set[string]) + for _, ability := range enableAbilities { - groups := modelGroupsMap[ability.Model] - if groups == nil { - groups = make([]string, 0) + groups, ok := modelGroupsMap[ability.Model] + if !ok { + groups = types.NewSet[string]() + modelGroupsMap[ability.Model] = groups } - if !common.StringsContains(groups, ability.Group) { - groups = append(groups, ability.Group) + groups.Add(ability.Group) + } + + //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点 + modelSupportEndpointsStr := make(map[string][]string) + + for _, ability := range enableAbilities { + endpoints, ok := modelSupportEndpointsStr[ability.Model] + if !ok { + endpoints = make([]string, 0) + modelSupportEndpointsStr[ability.Model] = endpoints } - modelGroupsMap[ability.Model] = groups + channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) + for _, channelType := range channelTypes { + if !common.StringsContains(endpoints, string(channelType)) { + endpoints = append(endpoints, string(channelType)) + } + } + modelSupportEndpointsStr[ability.Model] = endpoints + } + + modelSupportEndpointTypes = make(map[string][]constant.EndpointType) + for model, endpoints := range modelSupportEndpointsStr { + supportedEndpoints := make([]constant.EndpointType, 0) + for _, endpointStr := range endpoints { + endpointType := constant.EndpointType(endpointStr) + supportedEndpoints = append(supportedEndpoints, endpointType) + } + modelSupportEndpointTypes[model] = supportedEndpoints } pricingMap = make([]Pricing, 0) for model, groups := range modelGroupsMap { pricing := Pricing{ - ModelName: model, - EnableGroup: groups, + ModelName: model, + EnableGroup: groups.Items(), + SupportedEndpointTypes: modelSupportEndpointTypes[model], } - modelPrice, findPrice := operation_setting.GetModelPrice(model, false) + modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) if findPrice { pricing.ModelPrice = modelPrice pricing.QuotaType = 1 } else { - modelRatio, _ := operation_setting.GetModelRatio(model) + modelRatio, _ := ratio_setting.GetModelRatio(model) pricing.ModelRatio = modelRatio - pricing.CompletionRatio = operation_setting.GetCompletionRatio(model) + pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) pricing.QuotaType = 0 } pricingMap = append(pricingMap, pricing) diff --git a/model/token.go b/model/token.go index 2ed2c09a..7e68f185 100644 --- a/model/token.go +++ b/model/token.go @@ -327,3 +327,37 @@ func CountUserTokens(userId int) (int64, error) { err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error return total, err } + +// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量 +func BatchDeleteTokens(ids []int, userId int) (int, error) { + if len(ids) == 0 { + return 0, errors.New("ids 不能为空!") + } + + tx := DB.Begin() + + var tokens []Token + if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil { + tx.Rollback() + return 0, err + } + + if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil { + tx.Rollback() + return 0, err + } + + if err := tx.Commit().Error; err != nil { + return 0, err + } + + if common.RedisEnabled { + gopool.Go(func() { + for _, t := range tokens { + _ = cacheDeleteToken(t.Key) + } + }) + } + + return len(tokens), nil +} diff --git a/model/token_cache.go b/model/token_cache.go index b2e0c951..5399dbc8 100644 --- a/model/token_cache.go +++ b/model/token_cache.go @@ -10,7 +10,7 @@ import ( func cacheSetToken(token Token) error { key := common.GenerateHMAC(token.Key) token.Clean() - err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second) + err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second) if err != nil { return err } diff --git a/model/user.go b/model/user.go index 1b3a04b6..bd685e54 100644 --- a/model/user.go +++ b/model/user.go @@ -41,6 +41,7 @@ type User struct { DeletedAt gorm.DeletedAt `gorm:"index"` LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"` Setting string `json:"setting" gorm:"type:text;column:setting"` + Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"` } func (user *User) ToBaseUser() *UserBase { @@ -113,7 +114,7 @@ func GetMaxUserId() int { return user.Id } -func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) { +func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) { // Start transaction tx := DB.Begin() if tx.Error != nil { @@ -133,7 +134,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) } // Get paginated users within same transaction - err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error + err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error if err != nil { tx.Rollback() return nil, 0, err @@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error { "display_name": newUser.DisplayName, "group": newUser.Group, "quota": newUser.Quota, + "remark": newUser.Remark, } if updatePassword { updates["password"] = newUser.Password diff --git a/model/user_cache.go b/model/user_cache.go index d74877bd..b4bc2f1e 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -24,12 +24,12 @@ type UserBase struct { } func (user *UserBase) WriteContext(c *gin.Context) { - c.Set(constant.ContextKeyUserGroup, user.Group) - c.Set(constant.ContextKeyUserQuota, user.Quota) - c.Set(constant.ContextKeyUserStatus, user.Status) - c.Set(constant.ContextKeyUserEmail, user.Email) - c.Set("username", user.Username) - c.Set(constant.ContextKeyUserSetting, user.GetSetting()) + common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group) + common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota) + common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status) + common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email) + common.SetContextKey(c, constant.ContextKeyUserName, user.Username) + common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting()) } func (user *UserBase) GetSetting() map[string]interface{} { @@ -70,7 +70,7 @@ func updateUserCache(user User) error { return common.RedisHSetObj( getUserCacheKey(user.Id), user.ToBaseUser(), - time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second, + time.Duration(common.RedisKeyCacheSeconds())*time.Second, ) } diff --git a/model/utils.go b/model/utils.go index e6b09aa5..1f8a0963 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,11 +2,12 @@ package model import ( "errors" - "github.com/bytedance/gopkg/util/gopool" - "gorm.io/gorm" "one-api/common" "sync" "time" + + "github.com/bytedance/gopkg/util/gopool" + "gorm.io/gorm" ) const ( @@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) { } func batchUpdate() { + // check if there's any data to update + hasData := false + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + if len(batchUpdateStores[i]) > 0 { + hasData = true + batchUpdateLocks[i].Unlock() + break + } + batchUpdateLocks[i].Unlock() + } + + if !hasData { + return + } + common.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() diff --git a/relay/relay-audio.go b/relay/audio_handler.go similarity index 91% rename from relay/relay-audio.go rename to relay/audio_handler.go index deb45c58..c1ce1a02 100644 --- a/relay/relay-audio.go +++ b/relay/audio_handler.go @@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. } func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) if err != nil { @@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) - } + promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) preConsumedTokens = promptTokens relayInfo.PromptTokens = promptTokens } @@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } }() - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - audioRequest.Model = relayInfo.UpstreamModelName - adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 50255d0a..2ff34e01 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -44,4 +44,6 @@ type TaskAdaptor interface { // FetchTask FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) + + ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index f30d4dc4..63525cc4 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string switch info.RelayMode { case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) case constant.RelayModeRerank: fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) case constant.RelayModeImagesGenerations: @@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { - return embeddingRequestOpenAI2Ali(request), nil + return request, nil } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 44203583..c84c7885 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -132,10 +132,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliTaskResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/ali/rerank.go b/relay/channel/ali/rerank.go index c9ae066a..ebfe26de 100644 --- a/relay/channel/ali/rerank.go +++ b/relay/channel/ali/rerank.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io" "net/http" + "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" @@ -35,10 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) var aliResponse AliRerankResponse err = json.Unmarshal(responseBody, &aliResponse) diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 2f1387c5..149c9b4b 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -39,34 +39,18 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque } func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var aliResponse AliEmbeddingResponse - err := json.NewDecoder(resp.Body).Decode(&aliResponse) + var fullTextResponse dto.OpenAIEmbeddingResponse + err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - - if aliResponse.Code != "" { - return &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: aliResponse.Message, - Type: aliResponse.Code, - Param: aliResponse.RequestId, - Code: aliResponse.Code, - }, - StatusCode: resp.StatusCode, - }, nil - } + common.CloseResponseBodyGracefully(resp) model := c.GetString("model") if model == "" { model = "text-embedding-v4" } - fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model) jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil @@ -186,10 +170,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith return false } }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) return nil, &usage } @@ -199,10 +180,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &aliResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 55b6c137..11492fe3 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -166,10 +166,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi return false } }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) return nil, &usage } @@ -179,10 +176,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -215,10 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -280,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) { } req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - res, err := service.GetImpatientHttpClient().Do(req) + res, err := service.GetHttpClient().Do(req) if err != nil { return nil, err } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index cb2c75b1..a8607d86 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -7,6 +7,7 @@ import ( "net/http" "one-api/common" "one-api/dto" + "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -113,7 +114,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla // BudgetTokens 为 max_tokens 的 80% claudeRequest.Thinking = &dto.Thinking{ Type: "enabled", - BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), + BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)), } // TODO: 临时处理 // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking @@ -122,6 +123,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking") } + if textRequest.Reasoning != nil { + var reasoning openrouter.RequestReasoning + if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil { + return nil, err + } + + budgetTokens := reasoning.MaxTokens + if budgetTokens > 0 { + claudeRequest.Thinking = &dto.Thinking{ + Type: "enabled", + BudgetTokens: &budgetTokens, + } + } + } + if textRequest.Stop != nil { // stop maybe string/array string, convert to array string switch textRequest.Stop.(type) { @@ -454,6 +470,7 @@ type ClaudeResponseInfo struct { Model string ResponseText strings.Builder Usage *dto.Usage + Done bool } func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool { @@ -461,20 +478,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { if claudeResponse.Type == "message_start" { - // message_start, 获取usage claudeInfo.ResponseId = claudeResponse.Message.Id claudeInfo.Model = claudeResponse.Message.Model + + // message_start, 获取usage claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens + claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens + claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens + claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens } else if claudeResponse.Type == "content_block_delta" { if claudeResponse.Delta.Text != nil { claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text) } + if claudeResponse.Delta.Thinking != "" { + claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking) + } } else if claudeResponse.Type == "message_delta" { - claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + // 最终的usage获取 if claudeResponse.Usage.InputTokens > 0 { + // 不叠加,只取最新的 claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens } - claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens + claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens + claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens + + // 判断是否完整 + claudeInfo.Done = true } else if claudeResponse.Type == "content_block_start" { } else { return false @@ -490,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode { var claudeResponse dto.ClaudeResponse - err := common.DecodeJsonStr(data, &claudeResponse) + err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError) @@ -506,25 +535,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } if info.RelayFormat == relaycommon.RelayFormatClaude { + FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) + if requestMode == RequestModeCompletion { - claudeInfo.ResponseText.WriteString(claudeResponse.Completion) } else { if claudeResponse.Type == "message_start" { // message_start, 获取usage info.UpstreamModelName = claudeResponse.Message.Model - claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens - claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens - claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens - claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens } else if claudeResponse.Type == "content_block_delta" { - claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText()) } else if claudeResponse.Type == "message_delta" { - if claudeResponse.Usage.InputTokens > 0 { - // 不叠加,只取最新的 - claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens - } - claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens - claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens } } helper.ClaudeChunkData(c, claudeResponse, data) @@ -544,29 +563,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { + + if requestMode == RequestModeCompletion { + claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + if claudeInfo.Usage.PromptTokens == 0 { + //上游出错 + } + if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { + if common.DebugEnabled { + common.SysError("claude response usage is not complete, maybe upstream error") + } + claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + } + } + if info.RelayFormat == relaycommon.RelayFormatClaude { - if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) - } else { - // 说明流模式建立失败,可能为官方出错 - if claudeInfo.Usage.PromptTokens == 0 { - //usage.PromptTokens = info.PromptTokens - } - if claudeInfo.Usage.CompletionTokens == 0 { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) - } - } + // } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) - } else { - if claudeInfo.Usage.PromptTokens == 0 { - //上游出错 - } - if claudeInfo.Usage.CompletionTokens == 0 { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) - } - } + if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) @@ -604,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode { var claudeResponse dto.ClaudeResponse - err := common.DecodeJson(data, &claudeResponse) + err := common.UnmarshalJson(data, &claudeResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError) } @@ -619,10 +634,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } if requestMode == RequestModeCompletion { - completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError) - } + completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) claudeInfo.Usage.PromptTokens = info.PromptTokens claudeInfo.Usage.CompletionTokens = completionTokens claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens @@ -645,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud case relaycommon.RelayFormatClaude: responseData = data } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(http.StatusOK) - _, err = c.Writer.Write(responseData) + + common.IOCopyBytesGracefully(c, nil, responseData) return nil } func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + claudeInfo := &ClaudeResponseInfo{ ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Created: common.GetTimestamp(), @@ -663,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - resp.Body.Close() if common.DebugEnabled { println("responseBody: ", string(responseBody)) } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index a487429c..1c3a26f7 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err := scanner.Err(); err != nil { common.LogError(c, "error_scanning_stream_response: "+err.Error()) } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) @@ -81,10 +81,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela } helper.Done(c) - err := resp.Body.Close() - if err != nil { - common.LogError(c, "close_response_body_failed: "+err.Error()) - } + common.CloseResponseBodyGracefully(resp) return nil, usage } @@ -94,10 +91,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) var response dto.TextResponse err = json.Unmarshal(responseBody, &response) if err != nil { @@ -108,7 +102,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) for _, choice := range response.Choices { responseText += choice.Message.StringContent() } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) response.Usage = *usage response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) @@ -127,10 +121,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &cfResp) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -150,7 +141,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) + usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 10c4328b..4637740d 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -3,7 +3,6 @@ package cohere import ( "bufio" "encoding/json" - "fmt" "github.com/gin-gonic/gin" "io" "net/http" @@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string { } func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + responseId := helper.GetResponseID(c) createdTime := common.GetTimestamp() usage := &dto.Usage{} responseText := "" @@ -163,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } }) if usage.PromptTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } return nil, usage } @@ -174,10 +173,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { @@ -218,10 +214,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 6db40213..6c08261b 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -48,10 +48,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) // convert coze response to openai response var response dto.TextResponse var cozeResponse CozeChatDetailResponse @@ -106,7 +103,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo var currentEvent string var currentData string - var usage dto.Usage + var usage = &dto.Usage{} for scanner.Scan() { line := scanner.Text() @@ -114,7 +111,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo if line == "" { if currentEvent != "" && currentData != "" { // handle last event - handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info) currentEvent = "" currentData = "" } @@ -134,7 +131,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo // Last event if currentEvent != "" && currentData != "" { - handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info) } if err := scanner.Err(); err != nil { @@ -143,12 +140,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo helper.Done(c) if usage.TotalTokens == 0 { - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count")) } - return nil, &usage + return nil, usage } func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 93e3e8d6..3a2845b3 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) // Send request - client := service.GetImpatientHttpClient() + client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { common.SysError("failed to send request: " + err.Error()) @@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re return true }) helper.Done(c) - err := resp.Body.Close() - if err != nil { - // return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - common.SysError("close_response_body_failed: " + err.Error()) - } if usage.TotalTokens == 0 { - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } usage.CompletionTokens += nodeToken return nil, usage @@ -264,10 +257,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &difyResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index e6f66d5f..968d9c9b 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -72,10 +72,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - // suffix -thinking and -nothinking - if strings.HasSuffix(info.OriginModelName, "-thinking") { + // 新增逻辑:处理 -thinking- 格式 + if strings.Contains(info.UpstreamModelName, "-thinking-") { + parts := strings.Split(info.UpstreamModelName, "-thinking-") + info.UpstreamModelName = parts[0] + } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配 info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") } } diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index fa9108df..b22e092a 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct { Seed int64 `json:"seed,omitempty"` ResponseModalities []string `json:"responseModalities,omitempty"` ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"` + SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config } type GeminiChatCandidate struct { diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index d9d0054d..52846c66 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -1,7 +1,6 @@ package gemini import ( - "encoding/json" "io" "net/http" "one-api/common" @@ -9,20 +8,19 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" "github.com/gin-gonic/gin" ) func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) { + defer common.CloseResponseBodyGracefully(resp) + // 读取响应体 responseBody, err := io.ReadAll(resp.Body) if err != nil { return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } - err = resp.Body.Close() - if err != nil { - return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } if common.DebugEnabled { println(string(responseBody)) @@ -30,28 +28,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela // 解析为 Gemini 原生响应格式 var geminiResponse GeminiChatResponse - err = common.DecodeJson(responseBody, &geminiResponse) + err = common.UnmarshalJson(responseBody, &geminiResponse) if err != nil { return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - // 检查是否有候选响应 - if len(geminiResponse.Candidates) == 0 { - return nil, &dto.OpenAIErrorWithStatusCode{ - Error: dto.OpenAIError{ - Message: "No candidates returned", - Type: "server_error", - Param: "", - Code: 500, - }, - StatusCode: resp.StatusCode, - } - } - // 计算使用量(基于 UsageMetadata) usage := dto.Usage{ PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, + CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount, TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, } @@ -66,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela } // 直接返回 Gemini 原生格式的 JSON 响应 - jsonResponse, err := json.Marshal(geminiResponse) + jsonResponse, err := common.EncodeJson(geminiResponse) if err != nil { return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError) } - // 设置响应头并写入响应 - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, err = c.Writer.Write(jsonResponse) - if err != nil { - return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError) - } + common.IOCopyBytesGracefully(c, resp, jsonResponse) return &usage, nil } @@ -88,9 +67,11 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info helper.SetEventStreamHeaders(c) + responseText := strings.Builder{} + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse - err := common.DecodeJsonStr(data, &geminiResponse) + err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) return false @@ -102,13 +83,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info if part.InlineData != nil && part.InlineData.MimeType != "" { imageCount++ } + if part.Text != "" { + responseText.WriteString(part.Text) + } } } // 更新使用量统计 if geminiResponse.UsageMetadata.TotalTokenCount != 0 { usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount - usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { @@ -121,7 +105,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info } // 直接发送 GeminiChatResponse 响应 - err = helper.ObjectData(c, geminiResponse) + err = helper.StringData(c, data) if err != nil { common.LogError(c, err.Error()) } @@ -135,8 +119,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info } } - // 计算最终使用量 - usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens + // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens + if usage.CompletionTokens == 0 { + str := responseText.String() + if len(str) > 0 { + usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 空补全,不需要使用量 + usage = &dto.Usage{} + } + } // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为 //helper.Done(c) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index e2288faf..1544e8cf 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -12,6 +12,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting/model_setting" + "strconv" "strings" "unicode/utf8" @@ -36,6 +37,102 @@ var geminiSupportedMimeTypes = map[string]bool{ "video/flv": true, } +// Gemini 允许的思考预算范围 +const ( + pro25MinBudget = 128 + pro25MaxBudget = 32768 + flash25MaxBudget = 24576 + flash25LiteMinBudget = 512 + flash25LiteMaxBudget = 24576 +) + +// clampThinkingBudget 根据模型名称将预算限制在允许的范围内 +func clampThinkingBudget(modelName string, budget int) int { + isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && + !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && + !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") + is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite") + + if is25FlashLite { + if budget < flash25LiteMinBudget { + return flash25LiteMinBudget + } + if budget > flash25LiteMaxBudget { + return flash25LiteMaxBudget + } + } else if isNew25Pro { + if budget < pro25MinBudget { + return pro25MinBudget + } + if budget > pro25MaxBudget { + return pro25MaxBudget + } + } else { // 其他模型 + if budget < 0 { + return 0 + } + if budget > flash25MaxBudget { + return flash25MaxBudget + } + } + return budget +} + +func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) { + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { + modelName := info.UpstreamModelName + isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && + !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && + !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") + + if strings.Contains(modelName, "-thinking-") { + parts := strings.SplitN(modelName, "-thinking-", 2) + if len(parts) == 2 && parts[1] != "" { + if budgetTokens, err := strconv.Atoi(parts[1]); err == nil { + clampedBudget := clampThinkingBudget(modelName, budgetTokens) + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + ThinkingBudget: common.GetPointer(clampedBudget), + IncludeThoughts: true, + } + } + } + } else if strings.HasSuffix(modelName, "-thinking") { + unsupportedModels := []string{ + "gemini-2.5-pro-preview-05-06", + "gemini-2.5-pro-preview-03-25", + } + isUnsupported := false + for _, unsupportedModel := range unsupportedModels { + if strings.HasPrefix(modelName, unsupportedModel) { + isUnsupported = true + break + } + } + + if isUnsupported { + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + IncludeThoughts: true, + } + } else { + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + IncludeThoughts: true, + } + if geminiRequest.GenerationConfig.MaxOutputTokens > 0 { + budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) + clampedBudget := clampThinkingBudget(modelName, int(budgetTokens)) + geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget) + } + } + } else if strings.HasSuffix(modelName, "-nothinking") { + if !isNew25Pro { + geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ + ThinkingBudget: common.GetPointer(0), + } + } + } + } +} + // Setting safety to the lowest possible values since Gemini is already powerless enough func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) { @@ -56,67 +153,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } } - if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - if strings.HasSuffix(info.OriginModelName, "-thinking") { - // 硬编码不支持 ThinkingBudget 的旧模型 - unsupportedModels := []string{ - "gemini-2.5-pro-preview-05-06", - "gemini-2.5-pro-preview-03-25", - } - - isUnsupported := false - for _, unsupportedModel := range unsupportedModels { - if strings.HasPrefix(info.OriginModelName, unsupportedModel) { - isUnsupported = true - break - } - } - - if isUnsupported { - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ - IncludeThoughts: true, - } - } else { - budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens) - - // 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围) - isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") && - !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") && - !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25") - - if isNew25Pro { - // 新的2.5pro模型:ThinkingBudget范围为128-32768 - if budgetTokens == 0 || budgetTokens < 128 { - budgetTokens = 128 - } else if budgetTokens > 32768 { - budgetTokens = 32768 - } - } else { - // 其他模型:ThinkingBudget范围为0-24576 - if budgetTokens == 0 || budgetTokens > 24576 { - budgetTokens = 24576 - } - } - - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ - ThinkingBudget: common.GetPointer(int(budgetTokens)), - IncludeThoughts: true, - } - } - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { - // 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128) - isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") && - !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") && - !strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25") - - if !isNew25Pro { - // 只有非新2.5pro模型才支持-nothinking - geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{ - ThinkingBudget: common.GetPointer(0), - } - } - } - } + ThinkingAdaptor(&geminiRequest, info) safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList)) for _, category := range SafetySettingList { @@ -283,7 +320,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon // 校验 MimeType 是否在 Gemini 支持的白名单中 if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok { - return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList()) + url := part.GetImageMedia().Url + return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList()) } parts = append(parts, GeminiPart{ @@ -341,7 +379,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon if content.Role == "assistant" { content.Role = "model" } - geminiRequest.Contents = append(geminiRequest.Contents, content) + if len(content.Parts) > 0 { + geminiRequest.Contents = append(geminiRequest.Contents, content) + } } if len(system_content) > 0 { @@ -611,9 +651,9 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse { } } -func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse { +func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Id: helper.GetResponseID(c), Object: "chat.completion", Created: common.GetTimestamp(), Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), @@ -754,14 +794,14 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { // responseText := "" - id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + id := helper.GetResponseID(c) createAt := common.GetTimestamp() var usage = &dto.Usage{} var imageCount int helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse - err := common.DecodeJsonStr(data, &geminiResponse) + err := common.UnmarshalJsonStr(data, &geminiResponse) if err != nil { common.LogError(c, "error unmarshalling stream response: "+err.Error()) return false @@ -826,15 +866,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println(string(responseBody)) } var geminiResponse GeminiChatResponse - err = common.DecodeJson(responseBody, &geminiResponse) + err = common.UnmarshalJson(responseBody, &geminiResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -849,7 +886,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re StatusCode: resp.StatusCode, }, nil } - fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse.Model = info.UpstreamModelName usage := dto.Usage{ PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, @@ -880,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re } func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, readErr := io.ReadAll(resp.Body) if readErr != nil { return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError) } - _ = resp.Body.Close() var geminiResponse GeminiEmbeddingResponse if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { @@ -916,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm } openAIResponse.Usage = *usage.(*dto.Usage) - jsonResponse, jsonErr := json.Marshal(openAIResponse) + jsonResponse, jsonErr := common.EncodeJson(openAIResponse) if jsonErr != nil { return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError) } - c.Writer.Header().Set("Content-Type", "application/json") - c.Writer.WriteHeader(resp.StatusCode) - _, _ = c.Writer.Write(jsonResponse) - + common.IOCopyBytesGracefully(c, resp, jsonResponse) return usage, nil } diff --git a/relay/channel/jina/constant.go b/relay/channel/jina/constant.go index 45fc44c9..be290fb6 100644 --- a/relay/channel/jina/constant.go +++ b/relay/channel/jina/constant.go @@ -3,6 +3,7 @@ package jina var ModelList = []string{ "jina-clip-v1", "jina-reranker-v2-base-multilingual", + "jina-reranker-m0", } var ChannelName = "jina" diff --git a/relay/channel/mokaai/relay-mokaai.go b/relay/channel/mokaai/relay-mokaai.go index d7580d7a..645475dd 100644 --- a/relay/channel/mokaai/relay-mokaai.go +++ b/relay/channel/mokaai/relay-mokaai.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/service" ) @@ -26,7 +27,7 @@ func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.Embeddin } return &dto.EmbeddingRequest{ Input: input, - Model: request.Model, + Model: request.Model, } } @@ -53,10 +54,7 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &baiduResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -80,4 +78,3 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError _, err = c.Writer.Write(jsonResponse) return nil, &fullTextResponse.Usage } - diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 89a04646..bf7501e5 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -1,12 +1,12 @@ package ollama import ( - "bytes" "encoding/json" "fmt" "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/service" "strings" @@ -88,10 +88,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil @@ -120,31 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } - resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - // Copy headers - for k, v := range resp.Header { - // 删除任何现有的相同头部,以防止重复添加头部 - c.Writer.Header().Del(k) - for _, vv := range v { - c.Writer.Header().Add(k, vv) - } - } - // reset content length - c.Writer.Header().Del("Content-Length") - c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody))) - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.IOCopyBytesGracefully(c, resp, doResponseBody) return nil, usage } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index f0cf073f..711284f1 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -9,8 +9,7 @@ import ( "mime/multipart" "net/http" "net/textproto" - "one-api/common" - constant2 "one-api/constant" + "one-api/constant" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/ai360" @@ -21,7 +20,7 @@ import ( "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" "one-api/relay/common_handler" - "one-api/relay/constant" + relayconstant "one-api/relay/constant" "one-api/service" "path/filepath" "strings" @@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType // initialize ThinkingContentInfo when thinking_to_content is enabled - if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content { + if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content { info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{ IsFirstThinkingContent: true, SendLastThinkingContent: false, @@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayFormat == relaycommon.RelayFormatClaude { return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } - if info.RelayMode == constant.RelayModeRealtime { + if info.RelayMode == relayconstant.RelayModeRealtime { if strings.HasPrefix(info.BaseUrl, "https://") { baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") baseUrl = "wss://" + baseUrl @@ -79,29 +78,36 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } switch info.ChannelType { - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: apiVersion := info.ApiVersion if apiVersion == "" { - apiVersion = constant2.AzureDefaultAPIVersion + apiVersion = constant.AzureDefaultAPIVersion } // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api requestURL := strings.Split(info.RequestURLPath, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") + + // 特殊处理 responses API + if info.RelayMode == relayconstant.RelayModeResponses { + requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview") + return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + } + model_ := info.UpstreamModelName // 2025年5月10日后创建的渠道不移除. - if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime { + if info.ChannelCreateTime < constant.AzureNoRemoveDotTime { model_ = strings.Replace(model_, ".", "", -1) } // https://github.com/songquanpeng/one-api/issues/67 requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - if info.RelayMode == constant.RelayModeRealtime { + if info.RelayMode == relayconstant.RelayModeRealtime { requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) } return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil - case common.ChannelTypeMiniMax: + case constant.ChannelTypeMiniMax: return minimax.GetRequestURL(info) - case common.ChannelTypeCustom: + case constant.ChannelTypeCustom: url := info.BaseUrl url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil @@ -112,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, header) - if info.ChannelType == common.ChannelTypeAzure { + if info.ChannelType == constant.ChannelTypeAzure { header.Set("api-key", info.ApiKey) return nil } - if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { + if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization { header.Set("OpenAI-Organization", info.Organization) } - if info.RelayMode == constant.RelayModeRealtime { + if info.RelayMode == relayconstant.RelayModeRealtime { swp := c.Request.Header.Get("Sec-WebSocket-Protocol") if swp != "" { items := []string{ @@ -138,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info * } else { header.Set("Authorization", "Bearer "+info.ApiKey) } - if info.ChannelType == common.ChannelTypeOpenRouter { + if info.ChannelType == constant.ChannelTypeOpenRouter { header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api") header.Set("X-Title", "New API") } @@ -149,9 +155,14 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure { + if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure { request.StreamOptions = nil } + if info.ChannelType == constant.ChannelTypeOpenRouter { + if len(request.Usage) == 0 { + request.Usage = json.RawMessage(`{"include":true}`) + } + } if strings.HasPrefix(request.Model, "o") { if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 { request.MaxCompletionTokens = request.MaxTokens @@ -193,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { a.ResponseFormat = request.ResponseFormat - if info.RelayMode == constant.RelayModeAudioSpeech { + if info.RelayMode == relayconstant.RelayModeAudioSpeech { jsonData, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("error marshalling object: %w", err) @@ -242,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { switch info.RelayMode { - case constant.RelayModeImagesEdits: + case relayconstant.RelayModeImagesEdits: var requestBody bytes.Buffer writer := multipart.NewWriter(&requestBody) @@ -399,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { - if info.RelayMode == constant.RelayModeAudioTranscription || - info.RelayMode == constant.RelayModeAudioTranslation || - info.RelayMode == constant.RelayModeImagesEdits { + if info.RelayMode == relayconstant.RelayModeAudioTranscription || + info.RelayMode == relayconstant.RelayModeAudioTranslation || + info.RelayMode == relayconstant.RelayModeImagesEdits { return channel.DoFormRequest(a, c, info, requestBody) - } else if info.RelayMode == constant.RelayModeRealtime { + } else if info.RelayMode == relayconstant.RelayModeRealtime { return channel.DoWssRequest(a, c, info, requestBody) } else { return channel.DoApiRequest(a, c, info, requestBody) @@ -412,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { - case constant.RelayModeRealtime: + case relayconstant.RelayModeRealtime: err, usage = OpenaiRealtimeHandler(c, info) - case constant.RelayModeAudioSpeech: + case relayconstant.RelayModeAudioSpeech: err, usage = OpenaiTTSHandler(c, resp, info) - case constant.RelayModeAudioTranslation: + case relayconstant.RelayModeAudioTranslation: fallthrough - case constant.RelayModeAudioTranscription: + case relayconstant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) - case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: err, usage = OpenaiHandlerWithUsage(c, resp, info) - case constant.RelayModeRerank: + case relayconstant.RelayModeRerank: err, usage = common_handler.RerankHandler(c, info, resp) - case constant.RelayModeResponses: + case relayconstant.RelayModeResponses: if info.IsStream { err, usage = OaiResponsesStreamHandler(c, resp, info) } else { @@ -442,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom func (a *Adaptor) GetModelList() []string { switch a.ChannelType { - case common.ChannelType360: + case constant.ChannelType360: return ai360.ModelList - case common.ChannelTypeMoonshot: + case constant.ChannelTypeMoonshot: return moonshot.ModelList - case common.ChannelTypeLingYiWanWu: + case constant.ChannelTypeLingYiWanWu: return lingyiwanwu.ModelList - case common.ChannelTypeMiniMax: + case constant.ChannelTypeMiniMax: return minimax.ModelList - case common.ChannelTypeXinference: + case constant.ChannelTypeXinference: return xinference.ModelList - case common.ChannelTypeOpenRouter: + case constant.ChannelTypeOpenRouter: return openrouter.ModelList default: return ModelList @@ -461,17 +472,17 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { switch a.ChannelType { - case common.ChannelType360: + case constant.ChannelType360: return ai360.ChannelName - case common.ChannelTypeMoonshot: + case constant.ChannelTypeMoonshot: return moonshot.ChannelName - case common.ChannelTypeLingYiWanWu: + case constant.ChannelTypeLingYiWanWu: return lingyiwanwu.ChannelName - case common.ChannelTypeMiniMax: + case constant.ChannelTypeMiniMax: return minimax.ChannelName - case common.ChannelTypeXinference: + case constant.ChannelTypeXinference: return xinference.ChannelName - case common.ChannelTypeOpenRouter: + case constant.ChannelTypeOpenRouter: return openrouter.ChannelName default: return ChannelName diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 2e3d8df1..7c283bd0 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -2,7 +2,6 @@ package openai import ( "bytes" - "encoding/json" "fmt" "io" "math" @@ -15,6 +14,7 @@ import ( "one-api/relay/helper" "one-api/service" "os" + "path/filepath" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo } var lastStreamResponse dto.ChatCompletionsStreamResponse - if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil { + if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil { return err } @@ -110,12 +110,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil } - containStreamUsage := false + defer common.CloseResponseBodyGracefully(resp) + + model := info.UpstreamModelName var responseId string var createAt int64 = 0 var systemFingerprint string - model := info.UpstreamModelName - + var containStreamUsage bool var responseTextBuilder strings.Builder var toolCount int var usage = &dto.Usage{} @@ -147,31 +148,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel return true }) + // 处理最后的响应 shouldSendLastResp := true - var lastStreamResponse dto.ChatCompletionsStreamResponse - err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse) - if err == nil { - responseId = lastStreamResponse.Id - createAt = lastStreamResponse.Created - systemFingerprint = lastStreamResponse.GetSystemFingerprint() - model = lastStreamResponse.Model - if service.ValidUsage(lastStreamResponse.Usage) { - containStreamUsage = true - usage = lastStreamResponse.Usage - if !info.ShouldIncludeUsage { - shouldSendLastResp = false - } - } - for _, choice := range lastStreamResponse.Choices { - if choice.FinishReason != nil { - shouldSendLastResp = true - } - } + if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, + &containStreamUsage, info, &shouldSendLastResp); err != nil { + common.SysError("error handling last response: " + err.Error()) } - if shouldSendLastResp { - sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) - //err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent) + if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI { + _ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent) } // 处理token计算 @@ -180,10 +165,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } if !containStreamUsage { - usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } else { - if info.ChannelType == common.ChannelTypeDeepSeek { + if info.ChannelType == constant.ChannelTypeDeepSeek { if usage.PromptCacheHitTokens != 0 { usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens } @@ -196,16 +181,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + var simpleResponse dto.OpenAITextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = common.DecodeJson(responseBody, &simpleResponse) + err = common.UnmarshalJson(responseBody, &simpleResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -215,7 +198,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, }, nil } - + forceFormat := false if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { forceFormat = forceFmt @@ -224,7 +207,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ @@ -237,7 +220,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI switch info.RelayFormat { case relaycommon.RelayFormatOpenAI: if forceFormat { - responseBody, err = json.Marshal(simpleResponse) + responseBody, err = common.EncodeJson(simpleResponse) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } @@ -246,40 +229,26 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI } case relaycommon.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) - claudeRespStr, err := json.Marshal(claudeResp) + claudeRespStr, err := common.EncodeJson(claudeResp) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } responseBody = claudeRespStr } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - //return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() + common.IOCopyBytesGracefully(c, resp, responseBody) + return nil, &simpleResponse.Usage } func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { // the status code has been judged before, if there is a body reading failure, // it should be regarded as a non-recoverable error, so it should not return err for external retry. - // Analogous to nginx's load balancing, it will only retry if it can't be requested or - // if the upstream returns a specific status code, once the upstream has already written the header, - // the subsequent failure of the response body should be regarded as a non-recoverable error, + // Analogous to nginx's load balancing, it will only retry if it can't be requested or + // if the upstream returns a specific status code, once the upstream has already written the header, + // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. - defer resp.Body.Close() + defer common.CloseResponseBodyGracefully(resp) usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens @@ -296,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + // count tokens by audio file duration audioTokens, err := countAudioTokens(c) if err != nil { @@ -305,25 +276,8 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - resp.Body.Close() + // 写入新的 response body + common.IOCopyBytesGracefully(c, resp, responseBody) usage := &dto.Usage{} usage.PromptTokens = audioTokens @@ -345,13 +299,14 @@ func countAudioTokens(c *gin.Context) (int, error) { if err = c.ShouldBind(&reqBody); err != nil { return 0, errors.WithStack(err) } - + ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 reqFp, err := reqBody.File.Open() if err != nil { return 0, errors.WithStack(err) } + defer reqFp.Close() - tmpFp, err := os.CreateTemp("", "audio-*") + tmpFp, err := os.CreateTemp("", "audio-*"+ext) if err != nil { return 0, errors.WithStack(err) } @@ -365,7 +320,7 @@ func countAudioTokens(c *gin.Context) (int, error) { return 0, errors.WithStack(err) } - duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name()) + duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext) if err != nil { return 0, errors.WithStack(err) } @@ -413,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } realtimeEvent := &dto.RealtimeEvent{} - err = json.Unmarshal(message, realtimeEvent) + err = common.UnmarshalJson(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -473,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } info.SetFirstResponseTime() realtimeEvent := &dto.RealtimeEvent{} - err = json.Unmarshal(message, realtimeEvent) + err = common.UnmarshalJson(message, realtimeEvent) if err != nil { errChan <- fmt.Errorf("error unmarshalling message: %v", err) return @@ -520,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage = &dto.RealtimeUsage{} // print now usage } - //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -599,40 +554,25 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R } func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - // Reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - // reset content length - c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody))) - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } var usageResp dto.SimpleResponse - err = json.Unmarshal(responseBody, &usageResp) + err = common.UnmarshalJson(responseBody, &usageResp) if err != nil { return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil } + + // 写入新的 response body + common.IOCopyBytesGracefully(c, resp, responseBody) + + // Once we've written to the client, we should not return errors anymore + // because the upstream has already consumed resources and returned content + // We should still perform billing even if parsing fails // format if usageResp.InputTokens > 0 { usageResp.PromptTokens += usageResp.InputTokens diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 1d1e060e..7f426c33 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -1,7 +1,6 @@ package openai import ( - "bytes" "fmt" "io" "net/http" @@ -16,17 +15,15 @@ import ( ) func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + // read response body var responsesResponse dto.OpenAIResponsesResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - err = common.DecodeJson(responseBody, &responsesResponse) + err = common.UnmarshalJson(responseBody, &responsesResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -41,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon. }, nil } - // reset response body - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the httpClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - // copy response body - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - common.SysError("error copying response body: " + err.Error()) - } - resp.Body.Close() + // 写入新的 response body + common.IOCopyBytesGracefully(c, resp, responseBody) + // compute usage usage := dto.Usage{} usage.PromptTokens = responsesResponse.Usage.InputTokens @@ -82,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc // 检查当前数据是否包含 completed 状态和 usage 信息 var streamResponse dto.ResponsesStreamResponse - if err := common.DecodeJsonStr(data, &streamResponse); err == nil { + if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil { sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": @@ -110,7 +94,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc tempStr := responseTextBuilder.String() if len(tempStr) > 0 { // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName) usage.CompletionTokens = completionTokens } } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 3a06e7ee..aee4a307 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 5c398b5e..44c60713 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -2,7 +2,6 @@ package palm import ( "encoding/json" - "fmt" "github.com/gin-gonic/gin" "io" "net/http" @@ -73,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + responseId := helper.GetResponseID(c) createdTime := common.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) @@ -84,12 +83,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit stopChan <- true return } - err = resp.Body.Close() - if err != nil { - common.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } + common.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { @@ -123,10 +117,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit return false } }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } + common.CloseResponseBodyGracefully(resp) return nil, responseText } @@ -135,10 +126,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { @@ -156,7 +144,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model) + completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/siliconflow/relay-siliconflow.go b/relay/channel/siliconflow/relay-siliconflow.go index a01e745c..a52ebfda 100644 --- a/relay/channel/siliconflow/relay-siliconflow.go +++ b/relay/channel/siliconflow/relay-siliconflow.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/service" ) @@ -14,10 +15,7 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) var siliconflowResp SFRerankResponse err = json.Unmarshal(responseBody, &siliconflowResp) if err != nil { diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go new file mode 100644 index 00000000..8d057513 --- /dev/null +++ b/relay/channel/task/jimeng/adaptor.go @@ -0,0 +1,380 @@ +package jimeng + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "one-api/model" + "sort" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +// ============================ +// Request / Response structures +// ============================ + +type requestPayload struct { + ReqKey string `json:"req_key"` + BinaryDataBase64 []string `json:"binary_data_base64,omitempty"` + ImageUrls []string `json:"image_urls,omitempty"` + Prompt string `json:"prompt,omitempty"` + Seed int64 `json:"seed"` + AspectRatio string `json:"aspect_ratio"` +} + +type responsePayload struct { + Code int `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Data struct { + TaskID string `json:"task_id"` + } `json:"data"` +} + +type responseTask struct { + Code int `json:"code"` + Data struct { + BinaryDataBase64 []interface{} `json:"binary_data_base64"` + ImageUrls interface{} `json:"image_urls"` + RespData string `json:"resp_data"` + Status string `json:"status"` + VideoUrl string `json:"video_url"` + } `json:"data"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Status int `json:"status"` + TimeElapsed string `json:"time_elapsed"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + accessKey string + secretKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl + + // apiKey format: "access_key|secret_key" + keyParts := strings.Split(info.ApiKey, "|") + if len(keyParts) == 2 { + a.accessKey = strings.TrimSpace(keyParts[0]) + a.secretKey = strings.TrimSpace(keyParts[1]) + } +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + action := constant.TaskActionGenerate + info.Action = action + + req := relaycommon.TaskSubmitReq{} + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Prompt) == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + return + } + + // Store into context for later usage + c.Set("task_request", req) + return nil +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + return a.signRequest(req, a.accessKey, a.secretKey) +} + +// BuildRequestBody converts request into Jimeng specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(relaycommon.TaskSubmitReq) + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, errors.Wrap(err, "convert request payload failed") + } + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + _ = resp.Body.Close() + + // Parse Jimeng response + var jResp responsePayload + if err := json.Unmarshal(responseBody, &jResp); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if jResp.Code != 10000 { + taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID}) + return jResp.Data.TaskID, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl) + payload := map[string]string{ + "req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774 + "task_id": taskID, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, errors.Wrap(err, "marshal fetch task payload failed") + } + + req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + + keyParts := strings.Split(key, "|") + if len(keyParts) != 2 { + return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'") + } + accessKey := strings.TrimSpace(keyParts[0]) + secretKey := strings.TrimSpace(keyParts[1]) + + if err := a.signRequest(req, accessKey, secretKey); err != nil { + return nil, errors.Wrap(err, "sign request failed") + } + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"jimeng_vgfm_t2v_l20"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "jimeng" +} + +func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error { + var bodyBytes []byte + var err error + + if req.Body != nil { + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return errors.Wrap(err, "read request body failed") + } + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind + } else { + bodyBytes = []byte{} + } + + payloadHash := sha256.Sum256(bodyBytes) + hexPayloadHash := hex.EncodeToString(payloadHash[:]) + + t := time.Now().UTC() + xDate := t.Format("20060102T150405Z") + shortDate := t.Format("20060102") + + req.Header.Set("Host", req.URL.Host) + req.Header.Set("X-Date", xDate) + req.Header.Set("X-Content-Sha256", hexPayloadHash) + + // Sort and encode query parameters to create canonical query string + queryParams := req.URL.Query() + sortedKeys := make([]string, 0, len(queryParams)) + for k := range queryParams { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + var queryParts []string + for _, k := range sortedKeys { + values := queryParams[k] + sort.Strings(values) + for _, v := range values { + queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v))) + } + } + canonicalQueryString := strings.Join(queryParts, "&") + + headersToSign := map[string]string{ + "host": req.URL.Host, + "x-date": xDate, + "x-content-sha256": hexPayloadHash, + } + if req.Header.Get("Content-Type") != "" { + headersToSign["content-type"] = req.Header.Get("Content-Type") + } + + var signedHeaderKeys []string + for k := range headersToSign { + signedHeaderKeys = append(signedHeaderKeys, k) + } + sort.Strings(signedHeaderKeys) + + var canonicalHeaders strings.Builder + for _, k := range signedHeaderKeys { + canonicalHeaders.WriteString(k) + canonicalHeaders.WriteString(":") + canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k])) + canonicalHeaders.WriteString("\n") + } + signedHeaders := strings.Join(signedHeaderKeys, ";") + + canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", + req.Method, + req.URL.Path, + canonicalQueryString, + canonicalHeaders.String(), + signedHeaders, + hexPayloadHash, + ) + + hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest)) + hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:]) + + region := "cn-north-1" + serviceName := "cv" + credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName) + stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s", + xDate, + credentialScope, + hexHashedCanonicalRequest, + ) + + kDate := hmacSHA256([]byte(secretKey), []byte(shortDate)) + kRegion := hmacSHA256(kDate, []byte(region)) + kService := hmacSHA256(kRegion, []byte(serviceName)) + kSigning := hmacSHA256(kService, []byte("request")) + signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign))) + + authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", + accessKey, + credentialScope, + signedHeaders, + signature, + ) + req.Header.Set("Authorization", authorization) + return nil +} + +func hmacSHA256(key []byte, data []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} + +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { + r := requestPayload{ + ReqKey: "jimeng_vgfm_i2v_l20", + Prompt: req.Prompt, + AspectRatio: "16:9", // Default aspect ratio + Seed: -1, // Default to random + } + + // Handle one-of image_urls or binary_data_base64 + if req.Image != "" { + if strings.HasPrefix(req.Image, "http") { + r.ImageUrls = []string{req.Image} + } else { + r.BinaryDataBase64 = []string{req.Image} + } + } + metadata := req.Metadata + medaBytes, err := json.Marshal(metadata) + if err != nil { + return nil, errors.Wrap(err, "metadata marshal metadata failed") + } + err = json.Unmarshal(medaBytes, &r) + if err != nil { + return nil, errors.Wrap(err, "unmarshal metadata failed") + } + return &r, nil +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resTask := responseTask{} + if err := json.Unmarshal(respBody, &resTask); err != nil { + return nil, errors.Wrap(err, "unmarshal task result failed") + } + taskResult := relaycommon.TaskInfo{} + if resTask.Code == 10000 { + taskResult.Code = 0 + } else { + taskResult.Code = resTask.Code // todo uni code + taskResult.Reason = resTask.Message + taskResult.Status = model.TaskStatusFailure + taskResult.Progress = "100%" + } + switch resTask.Data.Status { + case "in_queue": + taskResult.Status = model.TaskStatusQueued + taskResult.Progress = "10%" + case "done": + taskResult.Status = model.TaskStatusSuccess + taskResult.Progress = "100%" + } + taskResult.Url = resTask.Data.VideoUrl + return &taskResult, nil +} diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go new file mode 100644 index 00000000..afa39201 --- /dev/null +++ b/relay/channel/task/kling/adaptor.go @@ -0,0 +1,346 @@ +package kling + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/samber/lo" + "io" + "net/http" + "one-api/model" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" + + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +// ============================ +// Request / Response structures +// ============================ + +type SubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type requestPayload struct { + Prompt string `json:"prompt,omitempty"` + Image string `json:"image,omitempty"` + Mode string `json:"mode,omitempty"` + Duration string `json:"duration,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + ModelName string `json:"model_name,omitempty"` + CfgScale float64 `json:"cfg_scale,omitempty"` +} + +type responsePayload struct { + Code int `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` + Data struct { + TaskId string `json:"task_id"` + TaskStatus string `json:"task_status"` + TaskStatusMsg string `json:"task_status_msg"` + TaskResult struct { + Videos []struct { + Id string `json:"id"` + Url string `json:"url"` + Duration string `json:"duration"` + } `json:"videos"` + } `json:"task_result"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + } `json:"data"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + accessKey string + secretKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl + + // apiKey format: "access_key|secret_key" + keyParts := strings.Split(info.ApiKey, "|") + if len(keyParts) == 2 { + a.accessKey = strings.TrimSpace(keyParts[0]) + a.secretKey = strings.TrimSpace(keyParts[1]) + } +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + action := constant.TaskActionGenerate + info.Action = action + + var req SubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Prompt) == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + return + } + + // Store into context for later usage + c.Set("task_request", req) + return nil +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") + return fmt.Sprintf("%s%s", a.baseURL, path), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + token, err := a.createJWTToken() + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + return nil +} + +// BuildRequestBody converts request into Kling specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(SubmitReq) + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, err + } + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + if action := c.GetString("action"); action != "" { + info.Action = action + } + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + + // Attempt Kling response parse first. + var kResp responsePayload + if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 { + c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId}) + return kResp.Data.TaskId, responseBody, nil + } + + // Fallback generic task response. + var generic dto.TaskResponse[string] + if err := json.Unmarshal(responseBody, &generic); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if !generic.IsSuccess() { + taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": generic.Data}) + return generic.Data, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + action, ok := body["action"].(string) + if !ok { + return nil, fmt.Errorf("invalid action") + } + path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") + url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + token, err := a.createJWTTokenWithKey(key) + if err != nil { + token = key + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"kling-v1", "kling-v1-6", "kling-v2-master"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "kling" +} + +// ============================ +// helpers +// ============================ + +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { + r := requestPayload{ + Prompt: req.Prompt, + Image: req.Image, + Mode: defaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + AspectRatio: a.getAspectRatio(req.Size), + ModelName: req.Model, + CfgScale: 0.5, + } + if r.ModelName == "" { + r.ModelName = "kling-v1" + } + metadata := req.Metadata + medaBytes, err := json.Marshal(metadata) + if err != nil { + return nil, errors.Wrap(err, "metadata marshal metadata failed") + } + err = json.Unmarshal(medaBytes, &r) + if err != nil { + return nil, errors.Wrap(err, "unmarshal metadata failed") + } + return &r, nil +} + +func (a *TaskAdaptor) getAspectRatio(size string) string { + switch size { + case "1024x1024", "512x512": + return "1:1" + case "1280x720", "1920x1080": + return "16:9" + case "720x1280", "1080x1920": + return "9:16" + default: + return "1:1" + } +} + +func defaultString(s, def string) string { + if strings.TrimSpace(s) == "" { + return def + } + return s +} + +func defaultInt(v int, def int) int { + if v == 0 { + return def + } + return v +} + +// ============================ +// JWT helpers +// ============================ + +func (a *TaskAdaptor) createJWTToken() (string, error) { + return a.createJWTTokenWithKeys(a.accessKey, a.secretKey) +} + +func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { + parts := strings.Split(apiKey, "|") + if len(parts) != 2 { + return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'") + } + return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) +} + +func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) { + if accessKey == "" || secretKey == "" { + return "", fmt.Errorf("access key and secret key are required") + } + now := time.Now().Unix() + claims := jwt.MapClaims{ + "iss": accessKey, + "exp": now + 1800, // 30 minutes + "nbf": now - 5, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["typ"] = "JWT" + return token.SignedString([]byte(secretKey)) +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resPayload := responsePayload{} + err := json.Unmarshal(respBody, &resPayload) + if err != nil { + return nil, errors.Wrap(err, "failed to unmarshal response body") + } + taskInfo := &relaycommon.TaskInfo{} + taskInfo.Code = resPayload.Code + taskInfo.TaskID = resPayload.Data.TaskId + taskInfo.Reason = resPayload.Message + //任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败) + status := resPayload.Data.TaskStatus + switch status { + case "submitted": + taskInfo.Status = model.TaskStatusSubmitted + case "processing": + taskInfo.Status = model.TaskStatusInProgress + case "succeed": + taskInfo.Status = model.TaskStatusSuccess + case "failed": + taskInfo.Status = model.TaskStatusFailure + default: + return nil, fmt.Errorf("unknown task status: %s", status) + } + if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 { + video := videos[0] + taskInfo.Url = video.Url + } + return taskInfo, nil +} diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 03d60516..9c04c7ad 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -22,6 +22,10 @@ type TaskAdaptor struct { ChannelType int } +func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { + return nil, fmt.Errorf("not implement") // todo implement this method if needed +} + func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 44718a25..7ea3aae7 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = tencentHandler(c, resp) } diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index 1446e06e..a7106a88 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -124,10 +124,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError helper.Done(c) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" - } + common.CloseResponseBodyGracefully(resp) return nil, responseText } @@ -138,10 +135,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &tencentSb) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 31f84abf..e568f651 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix := "" if a.RequestMode == RequestModeGemini { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - // suffix -thinking and -nothinking - if strings.HasSuffix(info.OriginModelName, "-thinking") { + // 新增逻辑:处理 -thinking- 格式 + if strings.Contains(info.UpstreamModelName, "-thinking-") { + parts := strings.Split(info.UpstreamModelName, "-thinking-") + info.UpstreamModelName = parts[0] + } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配 info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") } } @@ -123,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if v, ok := claudeModelMap[info.UpstreamModelName]; ok { model = v } - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - region, - adc.ProjectID, - region, - model, - suffix, - ), nil + if region == "global" { + return fmt.Sprintf( + "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s", + adc.ProjectID, + model, + suffix, + ), nil + } else { + return fmt.Sprintf( + "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", + region, + adc.ProjectID, + region, + model, + suffix, + ), nil + } } else if a.RequestMode == RequestModeLlama { return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", diff --git a/relay/channel/vertex/service_account.go b/relay/channel/vertex/service_account.go index cc640803..1d41c945 100644 --- a/relay/channel/vertex/service_account.go +++ b/relay/channel/vertex/service_account.go @@ -11,6 +11,7 @@ import ( "net/http" "net/url" relaycommon "one-api/relay/common" + "one-api/service" "strings" "fmt" @@ -45,7 +46,7 @@ func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) { if err != nil { return "", fmt.Errorf("failed to create signed JWT: %w", err) } - newToken, err := exchangeJwtForAccessToken(signedJWT) + newToken, err := exchangeJwtForAccessToken(signedJWT, info) if err != nil { return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) } @@ -96,14 +97,25 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) { return signedToken, nil } -func exchangeJwtForAccessToken(signedJWT string) (string, error) { +func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) { authURL := "https://www.googleapis.com/oauth2/v4/token" data := url.Values{} data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") data.Set("assertion", signedJWT) - resp, err := http.PostForm(authURL, data) + var client *http.Client + var err error + if proxyURL, ok := info.ChannelSetting["proxy"]; ok { + client, err = service.NewProxyHttpClient(proxyURL.(string)) + if err != nil { + return "", fmt.Errorf("new proxy http client failed: %w", err) + } + } else { + client = service.GetHttpClient() + } + + resp, err := client.PostForm(authURL, data) if err != nil { return "", err } diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index a4a48ee9..78233934 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -1,15 +1,19 @@ package volcengine import ( + "bytes" "errors" "fmt" "io" + "mime/multipart" "net/http" + "net/textproto" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + "path/filepath" "strings" "github.com/gin-gonic/gin" @@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + switch info.RelayMode { + case constant.RelayModeImagesEdits: + + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + writer.WriteField("model", request.Model) + // 获取所有表单字段 + formData := c.Request.PostForm + // 遍历表单字段并打印输出 + for key, values := range formData { + if key == "model" { + continue + } + for _, value := range values { + writer.WriteField(key, value) + } + } + + // Parse the multipart form to handle both single image and multiple images + if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory + return nil, errors.New("failed to parse multipart form") + } + + if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil { + // Check if "image" field exists in any form, including array notation + var imageFiles []*multipart.FileHeader + var exists bool + + // First check for standard "image" field + if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 { + // If not found, check for "image[]" field + if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 { + // If still not found, iterate through all fields to find any that start with "image[" + foundArrayImages := false + for fieldName, files := range c.Request.MultipartForm.File { + if strings.HasPrefix(fieldName, "image[") && len(files) > 0 { + foundArrayImages = true + for _, file := range files { + imageFiles = append(imageFiles, file) + } + } + } + + // If no image fields found at all + if !foundArrayImages && (len(imageFiles) == 0) { + return nil, errors.New("image is required") + } + } + } + + // Process all image files + for i, fileHeader := range imageFiles { + file, err := fileHeader.Open() + if err != nil { + return nil, fmt.Errorf("failed to open image file %d: %w", i, err) + } + defer file.Close() + + // If multiple images, use image[] as the field name + fieldName := "image" + if len(imageFiles) > 1 { + fieldName = "image[]" + } + + // Determine MIME type based on file extension + mimeType := detectImageMimeType(fileHeader.Filename) + + // Create a form file with the appropriate content type + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename)) + h.Set("Content-Type", mimeType) + + part, err := writer.CreatePart(h) + if err != nil { + return nil, fmt.Errorf("create form part failed for image %d: %w", i, err) + } + + if _, err := io.Copy(part, file); err != nil { + return nil, fmt.Errorf("copy file failed for image %d: %w", i, err) + } + } + + // Handle mask file if present + if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 { + maskFile, err := maskFiles[0].Open() + if err != nil { + return nil, errors.New("failed to open mask file") + } + defer maskFile.Close() + + // Determine MIME type for mask file + mimeType := detectImageMimeType(maskFiles[0].Filename) + + // Create a form file with the appropriate content type + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename)) + h.Set("Content-Type", mimeType) + + maskPart, err := writer.CreatePart(h) + if err != nil { + return nil, errors.New("create form file failed for mask") + } + + if _, err := io.Copy(maskPart, maskFile); err != nil { + return nil, errors.New("copy mask file failed") + } + } + } else { + return nil, errors.New("no multipart form data found") + } + + // 关闭 multipart 编写器以设置分界线 + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return bytes.NewReader(requestBody.Bytes()), nil + + default: + return request, nil + } +} + +// detectImageMimeType determines the MIME type based on the file extension +func detectImageMimeType(filename string) string { + ext := strings.ToLower(filepath.Ext(filename)) + switch ext { + case ".jpg", ".jpeg": + return "image/jpeg" + case ".png": + return "image/png" + case ".webp": + return "image/webp" + default: + // Try to detect from extension if possible + if strings.HasPrefix(ext, ".jp") { + return "image/jpeg" + } + // Default to png as a fallback + return "image/png" + } } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil case constant.RelayModeEmbeddings: return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil + case constant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) @@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom } case constant.RelayModeEmbeddings: err, usage = openai.OpenaiHandler(c, resp, info) + case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: + err, usage = openai.OpenaiHandlerWithUsage(c, resp, info) } return } diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index e019c2dc..4a030e48 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -1,9 +1,7 @@ package xai import ( - "bytes" "encoding/json" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -13,6 +11,8 @@ import ( "one-api/relay/helper" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse { @@ -68,23 +68,21 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel }) if !containStreamUsage { - usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } helper.Done(c) - err := resp.Body.Close() - if err != nil { - //return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - common.SysError("close_response_body_failed: " + err.Error()) - } + common.CloseResponseBodyGracefully(resp) return nil, usage } func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + defer common.CloseResponseBodyGracefully(resp) + responseBody, err := io.ReadAll(resp.Body) - var response *dto.TextResponse - err = common.DecodeJson(responseBody, &response) + var response *dto.SimpleResponse + err = common.UnmarshalJson(responseBody, &response) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return nil, nil @@ -99,21 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo return nil, nil } - // set new body - resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.IOCopyBytesGracefully(c, resp, encodeJson) return nil, &response.Usage } diff --git a/relay/channel/xinference/dto.go b/relay/channel/xinference/dto.go index 2f12ad10..35f339fe 100644 --- a/relay/channel/xinference/dto.go +++ b/relay/channel/xinference/dto.go @@ -1,7 +1,7 @@ package xinference type XinRerankResponseDocument struct { - Document string `json:"document,omitempty"` + Document any `json:"document,omitempty"` Index int `json:"index"` RelevanceScore float64 `json:"relevance_score"` } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 744538e3..91cd384b 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -210,10 +210,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi return false } }) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) return nil, usage } @@ -223,10 +220,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) err = json.Unmarshal(responseBody, &zhipuResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil diff --git a/relay/claude_handler.go b/relay/claude_handler.go index fb68a88a..42139ddf 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { relayInfo.IsStream = true } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - textRequest.Model = relayInfo.UpstreamModelName - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { @@ -98,7 +96,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { // BudgetTokens 为 max_tokens 的 80% textRequest.Thinking = &dto.Thinking{ Type: "enabled", - BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage), + BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)), } // TODO: 临时处理 // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking @@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } if resp != nil { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index a842a58d..37161c16 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -34,9 +34,14 @@ type ClaudeConvertInfo struct { } const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" + RelayFormatOpenAI = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" ) type RerankerInfo struct { @@ -60,8 +65,8 @@ type RelayInfo struct { TokenId int TokenKey string UserId int - Group string - UserGroup string + UsingGroup string // 使用的分组 + UserGroup string // 用户所在分组 TokenUnlimited bool StartTime time.Time FirstResponseTime time.Time @@ -108,17 +113,17 @@ type RelayInfo struct { // 定义支持流式选项的通道类型 var streamSupportedChannels = map[int]bool{ - common.ChannelTypeOpenAI: true, - common.ChannelTypeAnthropic: true, - common.ChannelTypeAws: true, - common.ChannelTypeGemini: true, - common.ChannelCloudflare: true, - common.ChannelTypeAzure: true, - common.ChannelTypeVolcEngine: true, - common.ChannelTypeOllama: true, - common.ChannelTypeXai: true, - common.ChannelTypeDeepSeek: true, - common.ChannelTypeBaiduV2: true, + constant.ChannelTypeOpenAI: true, + constant.ChannelTypeAnthropic: true, + constant.ChannelTypeAws: true, + constant.ChannelTypeGemini: true, + constant.ChannelCloudflare: true, + constant.ChannelTypeAzure: true, + constant.ChannelTypeVolcEngine: true, + constant.ChannelTypeOllama: true, + constant.ChannelTypeXai: true, + constant.ChannelTypeDeepSeek: true, + constant.ChannelTypeBaiduV2: true, } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { @@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { info := GenRelayInfo(c) info.RelayMode = relayconstant.RelayModeRerank + info.RelayFormat = RelayFormatRerank info.RerankerInfo = &RerankerInfo{ Documents: req.Documents, ReturnDocuments: req.GetReturnDocuments(), @@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatOpenAIAudio + return info +} + +func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatEmbedding + return info +} + func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { info := GenRelayInfo(c) info.RelayMode = relayconstant.RelayModeResponses + info.RelayFormat = RelayFormatOpenAIResponses + + info.SupportStreamOptions = false + info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } @@ -175,42 +197,54 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel return info } -func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := c.GetInt("channel_type") - channelId := c.GetInt("channel_id") - channelSetting := c.GetStringMap("channel_setting") - paramOverride := c.GetStringMap("param_override") +func GenRelayInfoGemini(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatGemini + info.ShouldIncludeUsage = false + return info +} - tokenId := c.GetInt("token_id") - tokenKey := c.GetString("token_key") - userId := c.GetInt("id") - group := c.GetString("group") - tokenUnlimited := c.GetBool("token_unlimited_quota") - startTime := c.GetTime(constant.ContextKeyRequestStartTime) +func GenRelayInfoImage(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatOpenAIImage + return info +} + +func GenRelayInfo(c *gin.Context) *RelayInfo { + channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType) + channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId) + channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting) + paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride) + + tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId) + tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey) + userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) + tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) + startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) // firstResponseTime = time.Now() - 1 second - apiType, _ := relayconstant.ChannelType2APIType(channelType) + apiType, _ := common.ChannelType2APIType(channelType) info := &RelayInfo{ - UserQuota: c.GetInt(constant.ContextKeyUserQuota), - UserSetting: c.GetStringMap(constant.ContextKeyUserSetting), - UserEmail: c.GetString(constant.ContextKeyUserEmail), + UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota), + UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting), + UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail), isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: c.GetString("base_url"), + BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl), RequestURLPath: c.Request.URL.String(), ChannelType: channelType, ChannelId: channelId, TokenId: tokenId, TokenKey: tokenKey, UserId: userId, - Group: group, - UserGroup: c.GetString(constant.ContextKeyUserGroup), + UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), + UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), TokenUnlimited: tokenUnlimited, StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), - OriginModelName: c.GetString("original_model"), - UpstreamModelName: c.GetString("original_model"), + OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), + UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), //RecodeModelName: c.GetString("original_model"), IsModelMapped: false, ApiType: apiType, @@ -232,21 +266,17 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { info.RequestURLPath = "/v1" + info.RequestURLPath } if info.BaseUrl == "" { - info.BaseUrl = common.ChannelBaseURLs[channelType] + info.BaseUrl = constant.ChannelBaseURLs[channelType] } - if info.ChannelType == common.ChannelTypeAzure { + if info.ChannelType == constant.ChannelTypeAzure { info.ApiVersion = GetAPIVersion(c) } - if info.ChannelType == common.ChannelTypeVertexAi { + if info.ChannelType == constant.ChannelTypeVertexAi { info.ApiVersion = c.GetString("region") } if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } - // responses 模式不支持 StreamOptions - if relayconstant.RelayModeResponses == info.RelayMode { - info.SupportStreamOptions = false - } return info } @@ -283,3 +313,22 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { } return info } + +type TaskSubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type TaskInfo struct { + Code int `json:"code"` + TaskID string `json:"task_id"` + Status string `json:"status"` + Reason string `json:"reason,omitempty"` + Url string `json:"url,omitempty"` + Progress string `json:"progress,omitempty"` +} diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 7a4f44bb..29086585 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -6,7 +6,7 @@ import ( _ "image/gif" _ "image/jpeg" _ "image/png" - "one-api/common" + "one-api/constant" "strings" ) @@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { switch channelType { - case common.ChannelTypeOpenAI: + case constant.ChannelTypeOpenAI: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: + case constant.ChannelTypeAzure: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } diff --git a/relay/common_handler/rerank.go b/relay/common_handler/rerank.go index 496278b5..0df219e3 100644 --- a/relay/common_handler/rerank.go +++ b/relay/common_handler/rerank.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/relay/channel/xinference" relaycommon "one-api/relay/common" @@ -16,17 +17,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + common.CloseResponseBodyGracefully(resp) if common.DebugEnabled { println("reranker response body: ", string(responseBody)) } var jinaResp dto.RerankResponse - if info.ChannelType == common.ChannelTypeXinference { + if info.ChannelType == constant.ChannelTypeXinference { var xinRerankResponse xinference.XinRerankResponse - err = common.DecodeJson(responseBody, &xinRerankResponse) + err = common.UnmarshalJson(responseBody, &xinRerankResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } @@ -38,10 +36,16 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } if info.ReturnDocuments { var document any - if result.Document == "" { - document = info.Documents[result.Index] - } else { - document = result.Document + if result.Document != nil { + if doc, ok := result.Document.(string); ok { + if doc == "" { + document = info.Documents[result.Index] + } else { + document = doc + } + } else { + document = result.Document + } } respResult.Document = document } @@ -55,7 +59,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo }, } } else { - err = common.DecodeJson(responseBody, &jinaResp) + err = common.UnmarshalJson(responseBody, &jinaResp) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go deleted file mode 100644 index 3f1ecd78..00000000 --- a/relay/constant/api_type.go +++ /dev/null @@ -1,106 +0,0 @@ -package constant - -import ( - "one-api/common" -) - -const ( - APITypeOpenAI = iota - APITypeAnthropic - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini - APITypeZhipuV4 - APITypeOllama - APITypePerplexity - APITypeAws - APITypeCohere - APITypeDify - APITypeJina - APITypeCloudflare - APITypeSiliconFlow - APITypeVertexAi - APITypeMistral - APITypeDeepSeek - APITypeMokaAI - APITypeVolcEngine - APITypeBaiduV2 - APITypeOpenRouter - APITypeXinference - APITypeXai - APITypeCoze - APITypeDummy // this one is only for count, do not add any channel after this -) - -func ChannelType2APIType(channelType int) (int, bool) { - apiType := -1 - switch channelType { - case common.ChannelTypeOpenAI: - apiType = APITypeOpenAI - case common.ChannelTypeAnthropic: - apiType = APITypeAnthropic - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - case common.ChannelTypeZhipu_v4: - apiType = APITypeZhipuV4 - case common.ChannelTypeOllama: - apiType = APITypeOllama - case common.ChannelTypePerplexity: - apiType = APITypePerplexity - case common.ChannelTypeAws: - apiType = APITypeAws - case common.ChannelTypeCohere: - apiType = APITypeCohere - case common.ChannelTypeDify: - apiType = APITypeDify - case common.ChannelTypeJina: - apiType = APITypeJina - case common.ChannelCloudflare: - apiType = APITypeCloudflare - case common.ChannelTypeSiliconFlow: - apiType = APITypeSiliconFlow - case common.ChannelTypeVertexAi: - apiType = APITypeVertexAi - case common.ChannelTypeMistral: - apiType = APITypeMistral - case common.ChannelTypeDeepSeek: - apiType = APITypeDeepSeek - case common.ChannelTypeMokaAI: - apiType = APITypeMokaAI - case common.ChannelTypeVolcEngine: - apiType = APITypeVolcEngine - case common.ChannelTypeBaiduV2: - apiType = APITypeBaiduV2 - case common.ChannelTypeOpenRouter: - apiType = APITypeOpenRouter - case common.ChannelTypeXinference: - apiType = APITypeXinference - case common.ChannelTypeXai: - apiType = APITypeXai - case common.ChannelTypeCoze: - apiType = APITypeCoze - } - if apiType == -1 { - return APITypeOpenAI, false - } - return apiType, true -} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index f22a20bd..b5195752 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -29,6 +29,8 @@ const ( RelayModeMidjourneyShorten RelayModeSwapFace RelayModeMidjourneyUpload + RelayModeMidjourneyVideo + RelayModeMidjourneyEdits RelayModeAudioSpeech // tts RelayModeAudioTranscription // whisper @@ -38,6 +40,12 @@ const ( RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeKlingFetchByID + RelayModeKlingSubmit + + RelayModeJimengFetchByID + RelayModeJimengSubmit + RelayModeRerank RelayModeResponses @@ -77,7 +85,7 @@ func Path2RelayMode(path string) int { relayMode = RelayModeRerank } else if strings.HasPrefix(path, "/v1/realtime") { relayMode = RelayModeRealtime - } else if strings.HasPrefix(path, "/v1beta/models") { + } else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") { relayMode = RelayModeGemini } return relayMode @@ -102,6 +110,10 @@ func Path2RelayModeMidjourney(path string) int { relayMode = RelayModeMidjourneyUpload } else if strings.HasSuffix(path, "/mj/submit/imagine") { relayMode = RelayModeMidjourneyImagine + } else if strings.HasSuffix(path, "/mj/submit/video") { + relayMode = RelayModeMidjourneyVideo + } else if strings.HasSuffix(path, "/mj/submit/edits") { + relayMode = RelayModeMidjourneyEdits } else if strings.HasSuffix(path, "/mj/submit/blend") { relayMode = RelayModeMidjourneyBlend } else if strings.HasSuffix(path, "/mj/submit/describe") { @@ -133,3 +145,23 @@ func Path2RelaySuno(method, path string) int { } return relayMode } + +func Path2RelayKling(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { + relayMode = RelayModeKlingSubmit + } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { + relayMode = RelayModeKlingFetchByID + } + return relayMode +} + +func Path2RelayJimeng(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { + relayMode = RelayModeJimengSubmit + } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { + relayMode = RelayModeJimengFetchByID + } + return relayMode +} diff --git a/relay/relay_embedding.go b/relay/embedding_handler.go similarity index 94% rename from relay/relay_embedding.go rename to relay/embedding_handler.go index b4909849..849c70da 100644 --- a/relay/relay_embedding.go +++ b/relay/embedding_handler.go @@ -15,7 +15,7 @@ import ( ) func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) + token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) return token } @@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed } func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoEmbedding(c) var embeddingRequest *dto.EmbeddingRequest err := common.UnmarshalBodyReusable(c, &embeddingRequest) @@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - embeddingRequest.Model = relayInfo.UpstreamModelName - promptToken := getEmbeddingPromptToken(*embeddingRequest) relayInfo.PromptTokens = promptToken diff --git a/relay/relay-gemini.go b/relay/gemini_handler.go similarity index 63% rename from relay/relay-gemini.go rename to relay/gemini_handler.go index 21cf5e12..9185ce62 100644 --- a/relay/relay-gemini.go +++ b/relay/gemini_handler.go @@ -13,6 +13,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting" + "one-api/setting/model_setting" "strings" "github.com/gin-gonic/gin" @@ -59,7 +60,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, return sensitiveWords, err } -func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { +func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int { // 计算输入 token 数量 var inputTexts []string for _, content := range req.Contents { @@ -71,9 +72,36 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay } inputText := strings.Join(inputTexts, "\n") - inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) + inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) info.PromptTokens = inputTokens - return inputTokens, err + return inputTokens +} + +func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool { + if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil { + return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0 + } + return false +} + +func trimModelThinking(modelName string) string { + // 去除模型名称中的 -nothinking 后缀 + if strings.HasSuffix(modelName, "-nothinking") { + return strings.TrimSuffix(modelName, "-nothinking") + } + // 去除模型名称中的 -thinking 后缀 + if strings.HasSuffix(modelName, "-thinking") { + return strings.TrimSuffix(modelName, "-thinking") + } + + // 去除模型名称中的 -thinking-number + if strings.Contains(modelName, "-thinking-") { + parts := strings.Split(modelName, "-thinking-") + if len(parts) > 1 { + return parts[0] + "-thinking" + } + } + return modelName } func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -83,7 +111,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) } - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoGemini(c) // 检查 Gemini 流式模式 checkGeminiStreamMode(c, relayInfo) @@ -97,7 +125,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) } @@ -106,13 +134,28 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) } else { - promptTokens, err := getGeminiInputTokens(req, relayInfo) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) - } + promptTokens := getGeminiInputTokens(req, relayInfo) c.Set("prompt_tokens", promptTokens) } + if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { + if isNoThinkingRequest(req) { + // check is thinking + if !strings.Contains(relayInfo.OriginModelName, "-nothinking") { + // try to get no thinking model price + noThinkingModelName := relayInfo.OriginModelName + "-nothinking" + containPrice := helper.ContainPriceOrRatio(noThinkingModelName) + if containPrice { + relayInfo.OriginModelName = noThinkingModelName + relayInfo.UpstreamModelName = noThinkingModelName + } + } + } + if req.GenerationConfig.ThinkingConfig == nil { + gemini.ThinkingAdaptor(req, relayInfo) + } + } + priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens)) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) @@ -155,14 +198,33 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError) } + if common.DebugEnabled { + println("Gemini request body: %s", string(requestBody)) + } + resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) if err != nil { common.LogError(c, "Do gemini request failed: "+err.Error()) - return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + statusCodeMappingStr := c.GetString("status_code_mapping") + + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + if httpResp.StatusCode != http.StatusOK { + openaiErr = service.RelayErrorHandler(httpResp, false) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } } usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo) if openaiErr != nil { + service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index 9bf67c03..c1735149 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,12 +4,14 @@ import ( "encoding/json" "errors" "fmt" + common2 "one-api/common" + "one-api/dto" "one-api/relay/common" "github.com/gin-gonic/gin" ) -func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { +func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error { // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" && modelMapping != "{}" { @@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { info.UpstreamModelName = currentModel } } + if request != nil { + switch info.RelayFormat { + case common.RelayFormatGemini: + // Gemini 模型映射 + case common.RelayFormatClaude: + if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { + claudeRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIResponses: + if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { + openAIResponsesRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIAudio: + if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { + openAIAudioRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIImage: + if imageRequest, ok := request.(*dto.ImageRequest); ok { + imageRequest.Model = info.UpstreamModelName + } + case common.RelayFormatRerank: + if rerankRequest, ok := request.(*dto.RerankRequest); ok { + rerankRequest.Model = info.UpstreamModelName + } + case common.RelayFormatEmbedding: + if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { + embeddingRequest.Model = info.UpstreamModelName + } + default: + if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok { + openAIRequest.Model = info.UpstreamModelName + } else { + common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request)) + } + } + } return nil } diff --git a/relay/helper/price.go b/relay/helper/price.go index 1b52bf37..ab614cbd 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -5,12 +5,17 @@ import ( "one-api/common" constant2 "one-api/constant" relaycommon "one-api/relay/common" - "one-api/setting" - "one-api/setting/operation_setting" + "one-api/setting/ratio_setting" "github.com/gin-gonic/gin" ) +type GroupRatioInfo struct { + GroupRatio float64 + GroupSpecialRatio float64 + HasSpecialRatio bool +} + type PriceData struct { ModelPrice float64 ModelRatio float64 @@ -18,23 +23,51 @@ type PriceData struct { CacheRatio float64 CacheCreationRatio float64 ImageRatio float64 - GroupRatio float64 - UserGroupRatio float64 UsePrice bool ShouldPreConsumedQuota int + GroupRatioInfo GroupRatioInfo } func (p PriceData) ToSetting() string { - return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) + return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio) +} + +// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present +func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo { + groupRatioInfo := GroupRatioInfo{ + GroupRatio: 1.0, // default ratio + GroupSpecialRatio: -1, + } + + // check auto group + autoGroup, exists := ctx.Get("auto_group") + if exists { + if common.DebugEnabled { + println(fmt.Sprintf("final group: %s", autoGroup)) + } + relayInfo.UsingGroup = autoGroup.(string) + } + + // check user group special ratio + userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) + if ok { + // user group special ratio + groupRatioInfo.GroupSpecialRatio = userGroupRatio + groupRatioInfo.GroupRatio = userGroupRatio + groupRatioInfo.HasSpecialRatio = true + } else { + // normal group ratio + groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup) + } + + return groupRatioInfo } func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) { - modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false) - groupRatio := setting.GetGroupRatio(info.Group) - userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group) - if ok { - groupRatio = userGroupRatio - } + modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false) + + groupRatioInfo := HandleGroupRatio(c, info) + var preConsumedQuota int var modelRatio float64 var completionRatio float64 @@ -47,7 +80,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens preConsumedTokens = promptTokens + maxTokens } var success bool - modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName) + modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName) if !success { acceptUnsetRatio := false if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok { @@ -60,22 +93,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName) } } - completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName) - cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName) - cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName) - imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName) - ratio := modelRatio * groupRatio + completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName) + cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName) + cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName) + imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName) + ratio := modelRatio * groupRatioInfo.GroupRatio preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) } priceData := PriceData{ ModelPrice: modelPrice, ModelRatio: modelRatio, CompletionRatio: completionRatio, - GroupRatio: groupRatio, - UserGroupRatio: userGroupRatio, + GroupRatioInfo: groupRatioInfo, UsePrice: usePrice, CacheRatio: cacheRatio, ImageRatio: imageRatio, @@ -90,12 +122,41 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens return priceData, nil } +type PerCallPriceData struct { + ModelPrice float64 + Quota int + GroupRatioInfo GroupRatioInfo +} + +// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData { + groupRatioInfo := HandleGroupRatio(c, info) + + modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) + // 如果没有配置价格,则使用默认价格 + if !success { + defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) + priceData := PerCallPriceData{ + ModelPrice: modelPrice, + Quota: quota, + GroupRatioInfo: groupRatioInfo, + } + return priceData +} + func ContainPriceOrRatio(modelName string) bool { - _, ok := operation_setting.GetModelPrice(modelName, false) + _, ok := ratio_setting.GetModelPrice(modelName, false) if ok { return true } - _, ok = operation_setting.GetModelRatio(modelName) + _, ok = ratio_setting.GetModelRatio(modelName) if ok { return true } diff --git a/relay/helper/stream_scanner.go b/relay/helper/stream_scanner.go index a69877e2..b526b1c0 100644 --- a/relay/helper/stream_scanner.go +++ b/relay/helper/stream_scanner.go @@ -20,8 +20,8 @@ import ( ) const ( - InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) - MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024) + InitialScannerBufferSize = 64 << 10 // 64KB (64*1024) + MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024) DefaultPingInterval = 10 * time.Second ) @@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon scanner = bufio.NewScanner(resp.Body) ticker = time.NewTicker(streamingTimeout) pingTicker *time.Ticker - writeMutex sync.Mutex // Mutex to protect concurrent writes + writeMutex sync.Mutex // Mutex to protect concurrent writes wg sync.WaitGroup // 用于等待所有 goroutine 退出 ) @@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon pingTicker = time.NewTicker(pingInterval) } + if common.DebugEnabled { + // print timeout and ping interval for debugging + println("relay timeout seconds:", common.RelayTimeout) + println("streaming timeout seconds:", int64(streamingTimeout.Seconds())) + println("ping interval seconds:", int64(pingInterval.Seconds())) + } + // 改进资源清理,确保所有 goroutine 正确退出 defer func() { // 通知所有 goroutine 停止 common.SafeSendBool(stopChan, true) - + ticker.Stop() if pingTicker != nil { pingTicker.Stop() } - + // 等待所有 goroutine 退出,最多等待5秒 done := make(chan struct{}) go func() { wg.Wait() close(done) }() - + select { case <-done: case <-time.After(5 * time.Second): common.LogError(c, "timeout waiting for goroutines to exit") } - + close(stopChan) }() - + scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize) scanner.Split(bufio.ScanLines) SetEventStreamHeaders(c) @@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon println("ping goroutine exited") } }() - + // 添加超时保护,防止 goroutine 无限运行 maxPingDuration := 30 * time.Minute // 最大 ping 持续时间 pingTimeout := time.NewTimer(maxPingDuration) defer pingTimeout.Stop() - + for { select { case <-pingTicker.C: @@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer writeMutex.Unlock() done <- PingData(c) }() - + select { case err := <-done: if err != nil { @@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon println("scanner goroutine exited") } }() - + for scanner.Scan() { // 检查是否需要停止 select { @@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon return default: } - + ticker.Reset(streamingTimeout) data := scanner.Text() if common.DebugEnabled { @@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon data = strings.TrimSuffix(data, "\r") if !strings.HasPrefix(data, "[DONE]") { info.SetFirstResponseTime() - + // 使用超时机制防止写操作阻塞 done := make(chan bool, 1) go func() { @@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon defer writeMutex.Unlock() done <- dataHandler(data) }() - + select { case success := <-done: if !success { diff --git a/relay/relay-image.go b/relay/image_handler.go similarity index 95% rename from relay/relay-image.go rename to relay/image_handler.go index dc63cce8..5decb497 100644 --- a/relay/relay-image.go +++ b/relay/image_handler.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" @@ -44,6 +45,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. if imageRequest.N == 0 { imageRequest.N = 1 } + + if info.ApiType == constant.APITypeVolcEngine { + watermark := formData.Has("watermark") + imageRequest.Watermark = &watermark + } default: err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { @@ -102,7 +108,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. } func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoImage(c) imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { @@ -110,13 +116,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - imageRequest.Model = relayInfo.UpstreamModelName - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) @@ -162,7 +166,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { // reset model price priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N) - quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit) + quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit) userQuota, err = model.GetUserQuota(relayInfo.UserId, false) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 9d0a2077..cc09e4a6 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -13,9 +13,9 @@ import ( "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" + "one-api/relay/helper" "one-api/service" "one-api/setting" - "one-api/setting/operation_setting" "strconv" "strings" "time" @@ -106,6 +106,9 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse { midjourneyTask.StartTime = midjRequest.StartTime midjourneyTask.FinishTime = midjRequest.FinishTime midjourneyTask.ImageUrl = midjRequest.ImageUrl + midjourneyTask.VideoUrl = midjRequest.VideoUrl + videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls) + midjourneyTask.VideoUrls = string(videoUrlsStr) midjourneyTask.Status = midjRequest.Status midjourneyTask.FailReason = midjRequest.FailReason err = midjourneyTask.Update() @@ -136,6 +139,9 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo } else { midjourneyTask.ImageUrl = originTask.ImageUrl } + if originTask.VideoUrl != "" { + midjourneyTask.VideoUrl = originTask.VideoUrl + } midjourneyTask.Status = originTask.Status midjourneyTask.FailReason = originTask.FailReason midjourneyTask.Action = originTask.Action @@ -148,6 +154,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.Buttons = buttons } } + if originTask.VideoUrls != "" { + var videoUrls []dto.ImgUrls + err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls) + if err == nil { + midjourneyTask.VideoUrls = videoUrls + } + } if originTask.Properties != "" { var properties dto.Properties err := json.Unmarshal([]byte(originTask.Properties), &properties) @@ -174,18 +187,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required") } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - modelPrice, success := operation_setting.GetModelPrice(modelName, true) - // 如果没有配置价格,则使用默认价格 - if !success { - defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName] - if !ok { - modelPrice = 0.1 - } else { - modelPrice = defaultPrice - } - } - groupRatio := setting.GetGroupRatio(group) - ratio := modelPrice * groupRatio + + priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + userQuota, err := model.GetUserQuota(userId, false) if err != nil { return &dto.MidjourneyResponse{ @@ -193,9 +197,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { Description: err.Error(), } } - quota := int(ratio * common.QuotaPerUnit) - if userQuota-quota < 0 { + if userQuota-priceData.Quota < 0 { return &dto.MidjourneyResponse{ Code: 4, Description: "quota_not_enough", @@ -210,26 +213,18 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := service.PostConsumeQuota(relayInfo, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - //err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace) - other := make(map[string]interface{}) - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - quota, logContent, tokenId, userQuota, 0, false, group, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) + other := service.GenerateMjOtherInfo(priceData) + model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, + priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) + model.UpdateChannelUsedQuota(channelId, priceData.Quota) } }() midjResponse := &mjResp.Response @@ -250,7 +245,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { Progress: "0%", FailReason: "", ChannelId: c.GetInt("channel_id"), - Quota: quota, + Quota: priceData.Quota, } err = midjourneyTask.Insert() if err != nil { @@ -297,10 +292,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse { if err != nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed") } - _, err = io.Copy(c.Writer, bytes.NewBuffer(respBody)) - if err != nil { - return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed") - } + common.IOCopyBytesGracefully(c, nil, respBody) return nil } @@ -391,6 +383,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } relayMode = relayconstant.RelayModeMidjourneyChange } + if relayMode == relayconstant.RelayModeMidjourneyVideo { + midjRequest.Action = constant.MjActionVideo + } if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { @@ -399,6 +394,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons midjRequest.Action = constant.MjActionImagine } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 midjRequest.Action = constant.MjActionDescribe + } else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 + midjRequest.Action = constant.MjActionEdits } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 @@ -433,6 +430,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //} mjId = midjRequest.TaskId midjRequest.Action = constant.MjActionModal + } else if relayMode == relayconstant.RelayModeMidjourneyVideo { + midjRequest.Action = constant.MjActionVideo + if midjRequest.TaskId == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") + } else if midjRequest.Action == "" { + return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required") + } + mjId = midjRequest.TaskId } originTask := model.GetByMJId(userId, mjId) @@ -480,18 +485,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) modelName := service.CoverActionToModelName(midjRequest.Action) - modelPrice, success := operation_setting.GetModelPrice(modelName, true) - // 如果没有配置价格,则使用默认价格 - if !success { - defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName] - if !ok { - modelPrice = 0.1 - } else { - modelPrice = defaultPrice - } - } - groupRatio := setting.GetGroupRatio(group) - ratio := modelPrice * groupRatio + + priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + userQuota, err := model.GetUserQuota(userId, false) if err != nil { return &dto.MidjourneyResponse{ @@ -499,9 +495,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons Description: err.Error(), } } - quota := int(ratio * common.QuotaPerUnit) - if consumeQuota && userQuota-quota < 0 { + if consumeQuota && userQuota-priceData.Quota < 0 { return &dto.MidjourneyResponse{ Code: 4, Description: "quota_not_enough", @@ -516,22 +511,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons defer func() { if consumeQuota && midjResponseWithStatus.StatusCode == 200 { - err := service.PostConsumeQuota(relayInfo, quota, 0, true) + err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result) - other := make(map[string]interface{}) - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, - quota, logContent, tokenId, userQuota, 0, false, group, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) + other := service.GenerateMjOtherInfo(priceData) + model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName, + priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other) + model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) + model.UpdateChannelUsedQuota(channelId, priceData.Quota) } }() @@ -559,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons Progress: "0%", FailReason: "", ChannelId: c.GetInt("channel_id"), - Quota: quota, + Quota: priceData.Quota, } if midjResponse.Code == 3 { //无实例账号自动禁用渠道(No available account instance) diff --git a/relay/relay-text.go b/relay/relay-text.go index 3aa382e8..e0c8f047 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { // get & validate textRequest 获取并验证文本请求 textRequest, err := getAndValidateTextRequest(c, relayInfo) - if textRequest.WebSearchOptions != nil { - c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) - } if err != nil { common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error())) return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) } + if textRequest.WebSearchOptions != nil { + c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize) + } + if setting.ShouldCheckPromptSensitive() { words, err := checkRequestSensitive(textRequest, relayInfo) if err != nil { @@ -107,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - textRequest.Model = relayInfo.UpstreamModelName - // 获取 promptTokens,如果上下文中已经存在,则直接使用 var promptTokens int if value, exists := c.Get("prompt_tokens"); exists { @@ -252,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re case relayconstant.RelayModeChatCompletions: promptTokens, err = service.CountTokenChatRequest(info, *textRequest) case relayconstant.RelayModeCompletions: - promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) case relayconstant.RelayModeModerations: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) case relayconstant.RelayModeEmbeddings: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) default: err = errors.New("unknown relay mode") promptTokens = 0 @@ -361,9 +360,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, cacheRatio := priceData.CacheRatio imageRatio := priceData.ImageRatio modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice - userGroupRatio := priceData.UserGroupRatio // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -511,7 +509,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, if extraContent != "" { logContent += ", " + extraContent } - other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio) + other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true other["image_ratio"] = imageRatio @@ -543,5 +541,5 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, other["audio_input_price"] = audioInputPrice } model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 7bf0da9f..00e59eac 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,6 +1,7 @@ package relay import ( + "one-api/constant" commonconstant "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" @@ -22,6 +23,8 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" + "one-api/relay/channel/task/jimeng" + "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" @@ -30,7 +33,6 @@ import ( "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" - "one-api/relay/constant" ) func GetAdaptor(apiType int) channel.Adaptor { @@ -101,6 +103,10 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { // return &aiproxy.Adaptor{} case commonconstant.TaskPlatformSuno: return &suno.TaskAdaptor{} + case commonconstant.TaskPlatformKling: + return &kling.TaskAdaptor{} + case commonconstant.TaskPlatformJimeng: + return &jimeng.TaskAdaptor{} } return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 26874ba6..702cff4c 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -15,8 +14,9 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" - "one-api/setting" - "one-api/setting/operation_setting" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" ) /* @@ -38,9 +38,12 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) - modelPrice, success := operation_setting.GetModelPrice(modelName, true) + if platform == constant.TaskPlatformKling { + modelName = relayInfo.OriginModelName + } + modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { - defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] if !ok { modelPrice = 0.1 } else { @@ -49,8 +52,14 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } // 预扣 - groupRatio := setting.GetGroupRatio(relayInfo.Group) - ratio := modelPrice * groupRatio + groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) + var ratio float64 + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) + if hasUserGroupRatio { + ratio = modelPrice * userGroupRatio + } else { + ratio = modelPrice * groupRatio + } userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -119,12 +128,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } if quota != 0 { tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action) + gRatio := groupRatio + if hasUserGroupRatio { + gRatio = userGroupRatio + } + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action) other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio + if hasUserGroupRatio { + other["user_group_ratio"] = userGroupRatio + } model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0, - modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other) + modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other) model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } @@ -137,10 +153,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } relayInfo.ConsumeQuota = true // insert task - task := model.InitTask(constant.TaskPlatformSuno, relayInfo) + task := model.InitTask(platform, relayInfo) task.TaskID = taskID task.Quota = quota task.Data = taskData + task.Action = relayInfo.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) @@ -150,8 +167,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ - relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, - relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, + relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder, } func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { @@ -226,6 +244,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } +func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + taskId := c.Param("task_id") + userId := c.GetInt("id") + + originTask, exist, err := model.GetByTaskId(userId, taskId) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + return +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ TaskID: task.TaskID, diff --git a/relay/relay_rerank.go b/relay/rerank_handler.go similarity index 91% rename from relay/relay_rerank.go rename to relay/rerank_handler.go index 6ca98de7..5cf384a8 100644 --- a/relay/relay_rerank.go +++ b/relay/rerank_handler.go @@ -14,12 +14,10 @@ import ( ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { - token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) + token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) for _, document := range rerankRequest.Documents { - tkm, err := service.CountTokenInput(document, rerankRequest.Model) - if err == nil { - token += tkm - } + tkm := service.CountTokenInput(document, rerankRequest.Model) + token += tkm } return token } @@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - rerankRequest.Model = relayInfo.UpstreamModelName - promptToken := getRerankPromptToken(*rerankRequest) relayInfo.PromptTokens = promptToken @@ -82,12 +78,15 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } requestBody := bytes.NewBuffer(jsonData) - statusCodeMappingStr := c.GetString("status_code_mapping") + if common.DebugEnabled { + println(fmt.Sprintf("Rerank request body: %s", requestBody.String())) + } resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + statusCodeMappingStr := c.GetString("status_code_mapping") var httpResp *http.Response if resp != nil { httpResp = resp.(*http.Response) diff --git a/relay/relay-responses.go b/relay/responses_handler.go similarity index 93% rename from relay/relay-responses.go rename to relay/responses_handler.go index fd3ddb5a..e744e354 100644 --- a/relay/relay-responses.go +++ b/relay/responses_handler.go @@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom return sensitiveWords, err } -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) { - inputTokens, err := service.CountTokenInput(req.Input, req.Model) +func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { + inputTokens := service.CountTokenInput(req.Input, req.Model) info.PromptTokens = inputTokens - return inputTokens, err + return inputTokens } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) } } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) } - req.Model = relayInfo.UpstreamModelName + if value, exists := c.Get("prompt_tokens"); exists { promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) } else { - promptTokens, err := getInputTokens(req, relayInfo) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) - } + promptTokens := getInputTokens(req, relayInfo) c.Set("prompt_tokens", promptTokens) } diff --git a/relay/websocket.go b/relay/websocket.go index c815eb71..571f3a82 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -6,12 +6,10 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "net/http" - "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "one-api/service" - "one-api/setting" - "one-api/setting/operation_setting" ) func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi //isModelMapped = true } } - //relayInfo.UpstreamModelName = textRequest.Model - modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false) - groupRatio := setting.GetGroupRatio(relayInfo.Group) - var preConsumedQuota int - var ratio float64 - var modelRatio float64 - //err := service.SensitiveWordsCheck(textRequest) - - //if constant.ShouldCheckPromptSensitive() { - // err = checkRequestSensitive(textRequest, relayInfo) - // if err != nil { - // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) - // } - //} - - //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo) - //// count messages token error 计算promptTokens错误 - //if err != nil { - // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) - //} - // - if !getModelPriceSuccess { - preConsumedTokens := common.PreConsumedQuota - //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { - // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) - //} - modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName) - ratio = modelRatio * groupRatio - preConsumedQuota = int(float64(preConsumedTokens) * ratio) - } else { - preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) - relayInfo.UsePrice = true + priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) } // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if openaiErr != nil { return openaiErr } @@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi return openaiErr } service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + userQuota, priceData, "") return nil } diff --git a/router/api-router.go b/router/api-router.go index 45930246..db4c3898 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) + apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig) userRoute := apiRouter.Group("/user") { @@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } + ratioSyncRoute := apiRouter.Group("/ratio_sync") + ratioSyncRoute.Use(middleware.RootAuth()) + { + ratioSyncRoute.GET("/channels", controller.GetSyncableChannels) + ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios) + } channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { @@ -118,6 +125,7 @@ func SetApiRouter(router *gin.Engine) { tokenRoute.POST("/", controller.AddToken) tokenRoute.PUT("/", controller.UpdateToken) tokenRoute.DELETE("/:id", controller.DeleteToken) + tokenRoute.POST("/batch", controller.DeleteTokenBatch) } redemptionRoute := apiRouter.Group("/redemption") redemptionRoute.Use(middleware.AdminAuth()) diff --git a/router/main.go b/router/main.go index b8ac4055..0d2bfdce 100644 --- a/router/main.go +++ b/router/main.go @@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) + SetVideoRouter(router) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" diff --git a/router/relay-router.go b/router/relay-router.go index aa7f27a8..b48c9dc7 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -63,6 +63,7 @@ func SetRelayRouter(router *gin.Engine) { httpRouter.DELETE("/models/:model", controller.RelayNotImplemented) httpRouter.POST("/moderations", controller.Relay) httpRouter.POST("/rerank", controller.Relay) + httpRouter.POST("/models/*path", controller.Relay) } relayMjRouter := router.Group("/mj") @@ -102,6 +103,8 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney) relayMjRouter.POST("/submit/describe", controller.RelayMidjourney) relayMjRouter.POST("/submit/blend", controller.RelayMidjourney) + relayMjRouter.POST("/submit/edits", controller.RelayMidjourney) + relayMjRouter.POST("/submit/video", controller.RelayMidjourney) relayMjRouter.POST("/notify", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) diff --git a/router/video-router.go b/router/video-router.go new file mode 100644 index 00000000..9e605d54 --- /dev/null +++ b/router/video-router.go @@ -0,0 +1,24 @@ +package router + +import ( + "one-api/controller" + "one-api/middleware" + + "github.com/gin-gonic/gin" +) + +func SetVideoRouter(router *gin.Engine) { + videoV1Router := router.Group("/v1") + videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + { + videoV1Router.POST("/video/generations", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + } + + klingV1Router := router.Group("/kling/v1") + klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute()) + { + klingV1Router.POST("/videos/text2video", controller.RelayTask) + klingV1Router.POST("/videos/image2video", controller.RelayTask) + } +} diff --git a/service/channel.go b/service/channel.go index e3a76af4..d50de78d 100644 --- a/service/channel.go +++ b/service/channel.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" "one-api/setting/operation_setting" @@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b } if err.StatusCode == http.StatusForbidden { switch channelType { - case common.ChannelTypeGemini: + case constant.ChannelTypeGemini: return true } } @@ -59,6 +60,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b return true case "billing_not_active": return true + case "pre_consume_token_quota_failed": + return true } switch err.Error.Type { case "insufficient_quota": diff --git a/service/convert.go b/service/convert.go index cb964a46..c97f8475 100644 --- a/service/convert.go +++ b/service/convert.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" @@ -19,12 +20,12 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re Stream: claudeRequest.Stream, } - isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter + isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter - if claudeRequest.Thinking != nil { + if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" { if isOpenRouter { reasoning := openrouter.RequestReasoning{ - MaxTokens: claudeRequest.Thinking.BudgetTokens, + MaxTokens: claudeRequest.Thinking.GetBudgetTokens(), } reasoningJSON, err := json.Marshal(reasoning) if err != nil { @@ -276,12 +277,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon } if info.Done { claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) - if info.ClaudeConvertInfo.Usage != nil { + oaiUsage := info.ClaudeConvertInfo.Usage + if oaiUsage != nil { claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ Type: "message_delta", Usage: &dto.ClaudeUsage{ - InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens, - OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens, + InputTokens: oaiUsage.PromptTokens, + OutputTokens: oaiUsage.CompletionTokens, + CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens, + CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens, }, Delta: &dto.ClaudeMediaMessage{ StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)), diff --git a/service/error.go b/service/error.go index 1bf5992b..21835f2a 100644 --- a/service/error.go +++ b/service/error.go @@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { text := err.Error() lowerText := strings.ToLower(text) - if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) - text = "请求上游地址失败" + if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") { + if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } } openAIError := dto.OpenAIError{ Message: text, @@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { text := err.Error() lowerText := strings.ToLower(text) - if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - common.SysLog(fmt.Sprintf("error: %s", text)) - text = "请求上游地址失败" + if !strings.HasPrefix(lowerText, "get file base64 from url") { + if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } } claudeError := dto.ClaudeError{ Message: text, @@ -86,10 +90,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatu if err != nil { return } - err = resp.Body.Close() - if err != nil { - return - } + common.CloseResponseBodyGracefully(resp) var errResponse dto.GeneralErrorResponse err = json.Unmarshal(responseBody, &errResponse) if err != nil { diff --git a/service/file_decoder.go b/service/file_decoder.go index bbb188f8..c1d4fb0c 100644 --- a/service/file_decoder.go +++ b/service/file_decoder.go @@ -4,8 +4,10 @@ import ( "encoding/base64" "fmt" "io" + "one-api/common" "one-api/constant" "one-api/dto" + "strings" ) func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { @@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) { // Convert to base64 base64Data := base64.StdEncoding.EncodeToString(fileBytes) + mimeType := resp.Header.Get("Content-Type") + if len(strings.Split(mimeType, ";")) > 1 { + // If Content-Type has parameters, take the first part + mimeType = strings.Split(mimeType, ";")[0] + } + if mimeType == "application/octet-stream" { + if common.DebugEnabled { + println("MIME type is application/octet-stream, trying to guess from URL or filename") + } + // try to guess the MIME type from the url last segment + urlParts := strings.Split(url, "/") + if len(urlParts) > 0 { + lastSegment := urlParts[len(urlParts)-1] + if strings.Contains(lastSegment, ".") { + // Extract the file extension + filename := strings.Split(lastSegment, ".") + if len(filename) > 1 { + ext := strings.ToLower(filename[len(filename)-1]) + // Guess MIME type based on file extension + mimeType = GetMimeTypeByExtension(ext) + } + } + } else { + // try to guess the MIME type from the file extension + fileName := resp.Header.Get("Content-Disposition") + if fileName != "" { + // Extract the filename from the Content-Disposition header + parts := strings.Split(fileName, ";") + for _, part := range parts { + if strings.HasPrefix(strings.TrimSpace(part), "filename=") { + fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename=")) + // Remove quotes if present + if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' { + fileName = fileName[1 : len(fileName)-1] + } + // Guess MIME type based on file extension + if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" { + mimeType = GetMimeTypeByExtension(ext) + } + break + } + } + } + } + } + return &dto.LocalFileData{ Base64Data: base64Data, - MimeType: resp.Header.Get("Content-Type"), + MimeType: mimeType, Size: int64(len(fileBytes)), }, nil } + +func GetMimeTypeByExtension(ext string) string { + // Convert to lowercase for case-insensitive comparison + ext = strings.ToLower(ext) + switch ext { + // Text files + case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm": + return "text/plain" + + // Image files + case "jpg", "jpeg": + return "image/jpeg" + case "png": + return "image/png" + case "gif": + return "image/gif" + + // Audio files + case "mp3": + return "audio/mp3" + case "wav": + return "audio/wav" + case "mpeg": + return "audio/mpeg" + + // Video files + case "mp4": + return "video/mp4" + case "wmv": + return "video/wmv" + case "flv": + return "video/flv" + case "mov": + return "video/mov" + case "mpg": + return "video/mpg" + case "avi": + return "video/avi" + case "mpegps": + return "video/mpegps" + + // Document files + case "pdf": + return "application/pdf" + + default: + return "application/octet-stream" // Default for unknown types + } +} diff --git a/service/http_client.go b/service/http_client.go index 64a361cf..b191ddd7 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -13,9 +13,8 @@ import ( ) var httpClient *http.Client -var impatientHTTPClient *http.Client -func init() { +func InitHttpClient() { if common.RelayTimeout == 0 { httpClient = &http.Client{} } else { @@ -23,20 +22,12 @@ func init() { Timeout: time.Duration(common.RelayTimeout) * time.Second, } } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } } func GetHttpClient() *http.Client { return httpClient } -func GetImpatientHttpClient() *http.Client { - return impatientHTTPClient -} - // NewProxyHttpClient 创建支持代理的 HTTP 客户端 func NewProxyHttpClient(proxyURL string) (*http.Client, error) { if proxyURL == "" { diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 1edc9073..affae5fb 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -3,6 +3,7 @@ package service import ( "one-api/dto" relaycommon "one-api/relay/common" + "one-api/relay/helper" "github.com/gin-gonic/gin" ) @@ -63,3 +64,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, info["cache_creation_ratio"] = cacheCreationRatio return info } + +func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} { + other := make(map[string]interface{}) + other["model_price"] = priceData.ModelPrice + other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio + if priceData.GroupRatioInfo.HasSpecialRatio { + other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio + } + return other +} diff --git a/service/midjourney.go b/service/midjourney.go index 635c29ae..83404bd9 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -3,7 +3,6 @@ package service import ( "context" "encoding/json" - "github.com/gin-gonic/gin" "io" "log" "net/http" @@ -15,6 +14,8 @@ import ( "strconv" "strings" "time" + + "github.com/gin-gonic/gin" ) func CoverActionToModelName(mjAction string) string { @@ -38,6 +39,10 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin switch relayMode { case relayconstant.RelayModeMidjourneyImagine: action = constant.MjActionImagine + case relayconstant.RelayModeMidjourneyVideo: + action = constant.MjActionVideo + case relayconstant.RelayModeMidjourneyEdits: + action = constant.MjActionEdits case relayconstant.RelayModeMidjourneyDescribe: action = constant.MjActionDescribe case relayconstant.RelayModeMidjourneyBlend: @@ -228,10 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err } - err = resp.Body.Close() - if err != nil { - return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err - } + common.CloseResponseBodyGracefully(resp) respStr := string(responseBody) log.Printf("respStr: %s", respStr) if respStr == "" { diff --git a/service/quota.go b/service/quota.go index da3dd9b9..bc3ef296 100644 --- a/service/quota.go +++ b/service/quota.go @@ -3,14 +3,16 @@ package service import ( "errors" "fmt" + "log" + "math" "one-api/common" - constant2 "one-api/constant" + "one-api/constant" "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/setting" - "one-api/setting/operation_setting" + "one-api/setting/ratio_setting" "strings" "time" @@ -45,9 +47,9 @@ func calculateAudioQuota(info QuotaInfo) int { return int(quota.IntPart()) } - completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName)) - audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName)) - audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName)) + completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName)) + audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName)) + audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName)) groupRatio := decimal.NewFromFloat(info.GroupRatio) modelRatio := decimal.NewFromFloat(info.ModelRatio) @@ -93,12 +95,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag textOutTokens := usage.OutputTokenDetails.TextTokens audioInputTokens := usage.InputTokenDetails.AudioTokens audioOutTokens := usage.OutputTokenDetails.AudioTokens - groupRatio := setting.GetGroupRatio(relayInfo.Group) - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - groupRatio = userGroupRatio + groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) + modelRatio, _ := ratio_setting.GetModelRatio(modelName) + + autoGroup, exists := ctx.Get("auto_group") + if exists { + groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string)) + log.Printf("final group ratio: %f", groupRatio) + relayInfo.UsingGroup = autoGroup.(string) + } + + actualGroupRatio := groupRatio + userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) + if ok { + actualGroupRatio = userGroupRatio } - modelRatio, _ := operation_setting.GetModelRatio(modelName) quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ @@ -112,7 +123,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag ModelName: modelName, UsePrice: relayInfo.UsePrice, ModelRatio: modelRatio, - GroupRatio: groupRatio, + GroupRatio: actualGroupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -134,8 +145,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag } func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { + usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() textInputTokens := usage.InputTokenDetails.TextTokens @@ -145,15 +155,15 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod audioOutTokens := usage.OutputTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName)) - audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName)) - audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName)) + completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName)) + audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) + audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName)) + + modelRatio := priceData.ModelRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio + modelPrice := priceData.ModelPrice + usePrice := priceData.UsePrice - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - actualGroupRatio = userGroupRatio - } quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ TextTokens: textInputTokens, @@ -166,7 +176,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod ModelName: modelName, UsePrice: usePrice, ModelRatio: modelRatio, - GroupRatio: actualGroupRatio, + GroupRatio: groupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -198,9 +208,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod logContent += ", " + extraContent } other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, @@ -214,15 +224,25 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, tokenName := ctx.GetString("token_name") completionRatio := priceData.CompletionRatio modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice - userGroupRatio := priceData.UserGroupRatio cacheRatio := priceData.CacheRatio cacheTokens := usage.PromptTokensDetails.CachedTokens cacheCreationRatio := priceData.CacheCreationRatio cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens + if relayInfo.ChannelType == constant.ChannelTypeOpenRouter { + promptTokens -= cacheTokens + if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 { + maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData) + if promptTokens >= maybeCacheCreationTokens { + cacheCreationTokens = maybeCacheCreationTokens + } + } + promptTokens -= cacheCreationTokens + } + calculateQuota := 0.0 if !priceData.UsePrice { calculateQuota = float64(promptTokens) @@ -265,9 +285,30 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, - cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio) + cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) +} + +func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int { + if priceData.CacheCreationRatio == 1 { + return 0 + } + quotaPrice := priceData.ModelRatio / common.QuotaPerUnit + promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio + promptCacheReadPrice := quotaPrice * priceData.CacheRatio + completionPrice := quotaPrice * priceData.CompletionRatio + + cost := usage.Cost + totalPromptTokens := float64(usage.PromptTokens) + completionTokens := float64(usage.CompletionTokens) + promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens) + + return int(math.Round((cost - + totalPromptTokens*quotaPrice + + promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) - + completionTokens*completionPrice) / + (promptCacheCreatePrice - quotaPrice))) } func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, @@ -281,21 +322,15 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, audioOutTokens := usage.CompletionTokenDetails.AudioTokens tokenName := ctx.GetString("token_name") - completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName)) - audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName)) - audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) + completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName)) + audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName)) + audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName)) modelRatio := priceData.ModelRatio - groupRatio := priceData.GroupRatio + groupRatio := priceData.GroupRatioInfo.GroupRatio modelPrice := priceData.ModelPrice usePrice := priceData.UsePrice - actualGroupRatio := groupRatio - userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group) - if ok { - actualGroupRatio = userGroupRatio - } - quotaInfo := QuotaInfo{ InputDetails: TokenDetails{ TextTokens: textInputTokens, @@ -308,7 +343,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, ModelName: relayInfo.OriginModelName, UsePrice: usePrice, ModelRatio: modelRatio, - GroupRatio: actualGroupRatio, + GroupRatio: groupRatio, } quota := calculateAudioQuota(quotaInfo) @@ -348,9 +383,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, logContent += ", " + extraContent } other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, - completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio) + completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other) + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other) } func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error { @@ -412,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon gopool.Go(func() { userSetting := relayInfo.UserSetting threshold := common.QuotaRemindThreshold - if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok { + if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok { threshold = int(userCustomThreshold.(float64)) } diff --git a/service/token_counter.go b/service/token_counter.go index 82de0a05..eed5b5ca 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m if !constant.GetMediaToken { return 3 * baseTokens, nil } - if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic { + if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic { return 3 * baseTokens, nil } var config image.Config @@ -171,10 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA countStr += fmt.Sprintf("%v", tool.Function.Parameters) } } - toolTokens, err := CountTokenInput(countStr, request.Model) - if err != nil { - return 0, err - } + toolTokens := CountTokenInput(countStr, request.Model) tkm += 8 tkm += toolTokens } @@ -194,10 +191,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro // Count tokens in system message if request.System != "" { - systemTokens, err := CountTokenInput(request.System, model) - if err != nil { - return 0, err - } + systemTokens := CountTokenInput(request.System, model) tkm += systemTokens } @@ -296,10 +290,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, switch request.Type { case dto.RealtimeEventTypeSessionUpdate: if request.Session != nil { - msgTokens, err := CountTextToken(request.Session.Instructions, model) - if err != nil { - return 0, 0, err - } + msgTokens := CountTextToken(request.Session.Instructions, model) textToken += msgTokens } case dto.RealtimeEventResponseAudioDelta: @@ -311,10 +302,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, audioToken += atk case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta: // count text token - tkm, err := CountTextToken(request.Delta, model) - if err != nil { - return 0, 0, fmt.Errorf("error counting text token: %v", err) - } + tkm := CountTextToken(request.Delta, model) textToken += tkm case dto.RealtimeEventInputAudioBufferAppend: // count audio token @@ -329,10 +317,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, case "message": for _, content := range request.Item.Content { if content.Type == "input_text" { - tokens, err := CountTextToken(content.Text, model) - if err != nil { - return 0, 0, err - } + tokens := CountTextToken(content.Text, model) textToken += tokens } } @@ -343,10 +328,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, if !info.IsFirstRequest { if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 { for _, tool := range info.RealtimeTools { - toolTokens, err := CountTokenInput(tool, model) - if err != nil { - return 0, 0, err - } + toolTokens := CountTokenInput(tool, model) textToken += 8 textToken += toolTokens } @@ -409,7 +391,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod return tokenNum, nil } -func CountTokenInput(input any, model string) (int, error) { +func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: return CountTextToken(v, model) @@ -432,13 +414,13 @@ func CountTokenInput(input any, model string) (int, error) { func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int { tokens := 0 for _, message := range messages { - tkm, _ := CountTokenInput(message.Delta.GetContentString(), model) + tkm := CountTokenInput(message.Delta.GetContentString(), model) tokens += tkm if message.Delta.ToolCalls != nil { for _, tool := range message.Delta.ToolCalls { - tkm, _ := CountTokenInput(tool.Function.Name, model) + tkm := CountTokenInput(tool.Function.Name, model) tokens += tkm - tkm, _ = CountTokenInput(tool.Function.Arguments, model) + tkm = CountTokenInput(tool.Function.Arguments, model) tokens += tkm } } @@ -446,9 +428,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, return tokens } -func CountTTSToken(text string, model string) (int, error) { +func CountTTSToken(text string, model string) int { if strings.HasPrefix(model, "tts") { - return utf8.RuneCountInString(text), nil + return utf8.RuneCountInString(text) } else { return CountTextToken(text, model) } @@ -483,8 +465,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) //} // CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTextToken(text string, model string) (int, error) { - var err error +func CountTextToken(text string, model string) int { + if text == "" { + return 0 + } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text), err + return getTokenNum(tokenEncoder, text) } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index c52e1e15..ca9c0830 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -16,13 +16,13 @@ import ( // return 0, errors.New("unknown relay mode") //} -func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { +func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTextToken(responseText, modeName) + ctkm := CountTextToken(responseText, modeName) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage, err + return usage } func ValidUsage(usage *dto.Usage) bool { diff --git a/service/webhook.go b/service/webhook.go index ad2967eb..8faccda3 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error } // 发送请求 - client := GetImpatientHttpClient() + client := GetHttpClient() resp, err = client.Do(req) if err != nil { return fmt.Errorf("failed to send webhook request: %v", err) diff --git a/setting/auto_group.go b/setting/auto_group.go new file mode 100644 index 00000000..5a87ae56 --- /dev/null +++ b/setting/auto_group.go @@ -0,0 +1,31 @@ +package setting + +import "encoding/json" + +var AutoGroups = []string{ + "default", +} + +var DefaultUseAutoGroup = false + +func ContainsAutoGroup(group string) bool { + for _, autoGroup := range AutoGroups { + if autoGroup == group { + return true + } + } + return false +} + +func UpdateAutoGroupsByJsonString(jsonString string) error { + AutoGroups = make([]string, 0) + return json.Unmarshal([]byte(jsonString), &AutoGroups) +} + +func AutoGroups2JsonString() string { + jsonBytes, err := json.Marshal(AutoGroups) + if err != nil { + return "[]" + } + return string(jsonBytes) +} diff --git a/setting/console_setting/validation.go b/setting/console_setting/validation.go index 51a84849..fda6453d 100644 --- a/setting/console_setting/validation.go +++ b/setting/console_setting/validation.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" "time" + "sort" ) var ( @@ -210,8 +211,23 @@ func validateFAQ(faqStr string) error { return nil } +func getPublishTime(item map[string]interface{}) time.Time { + if v, ok := item["publishDate"]; ok { + if s, ok2 := v.(string); ok2 { + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t + } + } + } + return time.Time{} +} + func GetAnnouncements() []map[string]interface{} { - return getJSONList(GetConsoleSetting().Announcements) + list := getJSONList(GetConsoleSetting().Announcements) + sort.SliceStable(list, func(i, j int) bool { + return getPublishTime(list[i]).After(getPublishTime(list[j])) + }) + return list } func GetFAQ() []map[string]interface{} { diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index 3e1af99e..a401b923 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -17,6 +17,8 @@ const ( const ( // Gemini Audio Input Price Gemini25FlashPreviewInputAudioPrice = 1.00 + Gemini25FlashProductionInputAudioPrice = 1.00 // for `gemini-2.5-flash` + Gemini25FlashLitePreviewInputAudioPrice = 0.50 Gemini25FlashNativeAudioInputAudioPrice = 3.00 Gemini20FlashInputAudioPrice = 0.70 ) @@ -64,10 +66,14 @@ func GetFileSearchPricePerThousand() float64 { } func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { - if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") { - return Gemini25FlashPreviewInputAudioPrice - } else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") { + if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") { return Gemini25FlashNativeAudioInputAudioPrice + } else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-lite") { + return Gemini25FlashLitePreviewInputAudioPrice + } else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") { + return Gemini25FlashPreviewInputAudioPrice + } else if strings.HasPrefix(modelName, "gemini-2.5-flash") { + return Gemini25FlashProductionInputAudioPrice } else if strings.HasPrefix(modelName, "gemini-2.0-flash") { return Gemini20FlashInputAudioPrice } diff --git a/setting/payment.go b/setting/payment.go index f50723c3..3fc0f14a 100644 --- a/setting/payment.go +++ b/setting/payment.go @@ -1,8 +1,45 @@ package setting +import "encoding/json" + var PayAddress = "" var CustomCallbackAddress = "" var EpayId = "" var EpayKey = "" var Price = 7.3 var MinTopUp = 1 + +var PayMethods = []map[string]string{ + { + "name": "支付宝", + "color": "rgba(var(--semi-blue-5), 1)", + "type": "alipay", + }, + { + "name": "微信", + "color": "rgba(var(--semi-green-5), 1)", + "type": "wxpay", + }, +} + +func UpdatePayMethodsByJsonString(jsonString string) error { + PayMethods = make([]map[string]string, 0) + return json.Unmarshal([]byte(jsonString), &PayMethods) +} + +func PayMethods2JsonString() string { + jsonBytes, err := json.Marshal(PayMethods) + if err != nil { + return "[]" + } + return string(jsonBytes) +} + +func ContainsPayMethod(method string) bool { + for _, payMethod := range PayMethods { + if payMethod["type"] == method { + return true + } + } + return false +} diff --git a/setting/operation_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go similarity index 90% rename from setting/operation_setting/cache_ratio.go rename to setting/ratio_setting/cache_ratio.go index ec0c766d..51d473a8 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -1,4 +1,4 @@ -package operation_setting +package ratio_setting import ( "encoding/json" @@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error { cacheRatioMapMutex.Lock() defer cacheRatioMapMutex.Unlock() cacheRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetCacheRatio returns the cache ratio for a model @@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) { } return ratio, true } + +func GetCacheRatioCopy() map[string]float64 { + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(cacheRatioMap)) + for k, v := range cacheRatioMap { + copyMap[k] = v + } + return copyMap +} diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go new file mode 100644 index 00000000..8fca0bcb --- /dev/null +++ b/setting/ratio_setting/expose_ratio.go @@ -0,0 +1,17 @@ +package ratio_setting + +import "sync/atomic" + +var exposeRatioEnabled atomic.Bool + +func init() { + exposeRatioEnabled.Store(false) +} + +func SetExposeRatioEnabled(enabled bool) { + exposeRatioEnabled.Store(enabled) +} + +func IsExposeRatioEnabled() bool { + return exposeRatioEnabled.Load() +} \ No newline at end of file diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go new file mode 100644 index 00000000..9e5b6c30 --- /dev/null +++ b/setting/ratio_setting/exposed_cache.go @@ -0,0 +1,55 @@ +package ratio_setting + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" +) + +const exposedDataTTL = 30 * time.Second + +type exposedCache struct { + data gin.H + expiresAt time.Time +} + +var ( + exposedData atomic.Value + rebuildMu sync.Mutex +) + +func InvalidateExposedDataCache() { + exposedData.Store((*exposedCache)(nil)) +} + +func cloneGinH(src gin.H) gin.H { + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func GetExposedData() gin.H { + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} \ No newline at end of file diff --git a/setting/group_ratio.go b/setting/ratio_setting/group_ratio.go similarity index 93% rename from setting/group_ratio.go rename to setting/ratio_setting/group_ratio.go index 28dbd167..86f4a8d1 100644 --- a/setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -1,4 +1,4 @@ -package setting +package ratio_setting import ( "encoding/json" @@ -73,15 +73,15 @@ func GetGroupRatio(name string) float64 { return ratio } -func GetGroupGroupRatio(group, name string) (float64, bool) { +func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) { groupGroupRatioMutex.RLock() defer groupGroupRatioMutex.RUnlock() - gp, ok := GroupGroupRatio[group] + gp, ok := GroupGroupRatio[userGroup] if !ok { return -1, false } - ratio, ok := gp[name] + ratio, ok := gp[usingGroup] if !ok { return -1, false } diff --git a/setting/operation_setting/model-ratio.go b/setting/ratio_setting/model_ratio.go similarity index 87% rename from setting/operation_setting/model-ratio.go rename to setting/ratio_setting/model_ratio.go index 700a7c4e..033b07a0 100644 --- a/setting/operation_setting/model-ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -1,8 +1,9 @@ -package operation_setting +package ratio_setting import ( "encoding/json" "one-api/common" + "one-api/setting/operation_setting" "strings" "sync" ) @@ -139,9 +140,17 @@ var defaultModelRatio = map[string]float64{ "gemini-2.0-flash": 0.05, "gemini-2.5-pro-exp-03-25": 0.625, "gemini-2.5-pro-preview-03-25": 0.625, + "gemini-2.5-pro": 0.625, "gemini-2.5-flash-preview-04-17": 0.075, "gemini-2.5-flash-preview-04-17-thinking": 0.075, "gemini-2.5-flash-preview-04-17-nothinking": 0.075, + "gemini-2.5-flash-preview-05-20": 0.075, + "gemini-2.5-flash-preview-05-20-thinking": 0.075, + "gemini-2.5-flash-preview-05-20-nothinking": 0.075, + "gemini-2.5-flash-thinking-*": 0.075, // 用于为后续所有2.5 flash thinking budget 模型设置默认倍率 + "gemini-2.5-pro-thinking-*": 0.625, // 用于为后续所有2.5 pro thinking budget 模型设置默认倍率 + "gemini-2.5-flash-lite-preview-06-17": 0.05, + "gemini-2.5-flash": 0.15, "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens @@ -222,7 +231,9 @@ var defaultModelPrice = map[string]float64{ "dall-e-3": 0.04, "imagen-3.0-generate-002": 0.03, "gpt-4-gizmo-*": 0.1, + "mj_video": 0.8, "mj_imagine": 0.1, + "mj_edits": 0.1, "mj_variation": 0.1, "mj_reroll": 0.1, "mj_blend": 0.1, @@ -311,7 +322,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error { modelPriceMapMutex.Lock() defer modelPriceMapMutex.Unlock() modelPriceMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPriceMap) + err := json.Unmarshal([]byte(jsonStr), &modelPriceMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false @@ -339,19 +354,33 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err +} + +// 处理带有思考预算的模型名称,方便统一定价 +func handleThinkingBudgetModel(name, prefix, wildcard string) string { + if strings.HasPrefix(name, prefix) && strings.Contains(name, "-thinking-") { + return wildcard + } + return name } func GetModelRatio(name string) (float64, bool) { modelRatioMapMutex.RLock() defer modelRatioMapMutex.RUnlock() + name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*") + name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*") if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } ratio, ok := modelRatioMap[name] if !ok { - return 37.5, SelfUseModeEnabled + return 37.5, operation_setting.SelfUseModeEnabled } return ratio, true } @@ -389,13 +418,22 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + if err == nil { + InvalidateExposedDataCache() + } + return err } func GetCompletionRatio(name string) float64 { CompletionRatioMutex.RLock() defer CompletionRatioMutex.RUnlock() - + if strings.HasPrefix(name, "gpt-4-gizmo") { + name = "gpt-4-gizmo-*" + } + if strings.HasPrefix(name, "gpt-4o-gizmo") { + name = "gpt-4o-gizmo-*" + } if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -413,12 +451,6 @@ func GetCompletionRatio(name string) float64 { func getHardcodedCompletionModelRatio(name string) (float64, bool) { lowercaseName := strings.ToLower(name) - if strings.HasPrefix(name, "gpt-4-gizmo") { - name = "gpt-4-gizmo-*" - } - if strings.HasPrefix(name, "gpt-4o-gizmo") { - name = "gpt-4o-gizmo-*" - } if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") { if strings.HasPrefix(name, "gpt-4o") { if name == "gpt-4o-2024-05-13" { @@ -470,14 +502,22 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { return 4, true } else if strings.HasPrefix(name, "gemini-2.0") { return 4, true - } else if strings.HasPrefix(name, "gemini-2.5-pro-preview") { - return 8, true - } else if strings.HasPrefix(name, "gemini-2.5-flash-preview") { - if strings.HasSuffix(name, "-nothinking") { - return 4, false - } else { - return 3.5 / 0.6, false + } else if strings.HasPrefix(name, "gemini-2.5-pro") { // 移除preview来增加兼容性,这里假设正式版的倍率和preview一致 + return 8, false + } else if strings.HasPrefix(name, "gemini-2.5-flash") { // 处理不同的flash模型倍率 + if strings.HasPrefix(name, "gemini-2.5-flash-preview") { + if strings.HasSuffix(name, "-nothinking") { + return 4, false + } + return 3.5 / 0.15, false } + if strings.HasPrefix(name, "gemini-2.5-flash-lite") { + if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") { + return 4, false + } + return 4, false + } + return 2.5 / 0.3, true } return 4, false } @@ -593,3 +633,33 @@ func GetImageRatio(name string) (float64, bool) { } return ratio, true } + +func GetModelRatioCopy() map[string]float64 { + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelRatioMap)) + for k, v := range modelRatioMap { + copyMap[k] = v + } + return copyMap +} + +func GetModelPriceCopy() map[string]float64 { + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelPriceMap)) + for k, v := range modelPriceMap { + copyMap[k] = v + } + return copyMap +} + +func GetCompletionRatioCopy() map[string]float64 { + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() + copyMap := make(map[string]float64, len(CompletionRatio)) + for k, v := range CompletionRatio { + copyMap[k] = v + } + return copyMap +} diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index 7082b683..fdf2f723 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool { _, ok := userUsableGroups[groupName] return ok } + +func GetUsableGroupDescription(groupName string) string { + if desc, ok := userUsableGroups[groupName]; ok { + return desc + } + return groupName +} diff --git a/types/set.go b/types/set.go new file mode 100644 index 00000000..db6b0272 --- /dev/null +++ b/types/set.go @@ -0,0 +1,42 @@ +package types + +type Set[T comparable] struct { + items map[T]struct{} +} + +// NewSet 创建并返回一个新的 Set +func NewSet[T comparable]() *Set[T] { + return &Set[T]{ + items: make(map[T]struct{}), + } +} + +func (s *Set[T]) Add(item T) { + s.items[item] = struct{}{} +} + +// Remove 从 Set 中移除一个元素 +func (s *Set[T]) Remove(item T) { + delete(s.items, item) +} + +// Contains 检查 Set 是否包含某个元素 +func (s *Set[T]) Contains(item T) bool { + _, exists := s.items[item] + return exists +} + +// Len 返回 Set 中元素的数量 +func (s *Set[T]) Len() int { + return len(s.items) +} + +// Items 返回 Set 中所有元素组成的切片 +// 注意:由于 map 的无序性,返回的切片元素顺序是随机的 +func (s *Set[T]) Items() []T { + items := make([]T, 0, s.Len()) + for item := range s.items { + items = append(items, item) + } + return items +} diff --git a/web/src/components/auth/LoginForm.js b/web/src/components/auth/LoginForm.js index c8847a33..ae7fc0fc 100644 --- a/web/src/components/auth/LoginForm.js +++ b/web/src/components/auth/LoginForm.js @@ -34,20 +34,20 @@ import LinuxDoIcon from '../common/logo/LinuxDoIcon.js'; import { useTranslation } from 'react-i18next'; const LoginForm = () => { + let navigate = useNavigate(); + const { t } = useTranslation(); const [inputs, setInputs] = useState({ username: '', password: '', wechat_verification_code: '', }); + const { username, password } = inputs; const [searchParams, setSearchParams] = useSearchParams(); const [submitted, setSubmitted] = useState(false); - const { username, password } = inputs; const [userState, userDispatch] = useContext(UserContext); const [turnstileEnabled, setTurnstileEnabled] = useState(false); const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); const [turnstileToken, setTurnstileToken] = useState(''); - let navigate = useNavigate(); - const [status, setStatus] = useState({}); const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); const [showEmailLogin, setShowEmailLogin] = useState(false); const [wechatLoading, setWechatLoading] = useState(false); @@ -59,7 +59,6 @@ const LoginForm = () => { const [resetPasswordLoading, setResetPasswordLoading] = useState(false); const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); - const { t } = useTranslation(); const logo = getLogo(); const systemName = getSystemName(); @@ -69,19 +68,22 @@ const LoginForm = () => { localStorage.setItem('aff', affCode); } + const [status] = useState(() => { + const savedStatus = localStorage.getItem('status'); + return savedStatus ? JSON.parse(savedStatus) : {}; + }); + + useEffect(() => { + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); + } + }, [status]); + useEffect(() => { if (searchParams.get('expired')) { showError(t('未登录或登录已过期,请重新登录')); } - let status = localStorage.getItem('status'); - if (status) { - status = JSON.parse(status); - setStatus(status); - if (status.turnstile_check) { - setTurnstileEnabled(true); - setTurnstileSiteKey(status.turnstile_site_key); - } - } }, []); const onWeChatLoginClicked = () => { @@ -356,9 +358,19 @@ const LoginForm = () => { -
- {t('没有账户?')} {t('注册')} -
+ {!status.self_use_mode_enabled && ( +
+ + {t('没有账户?')}{' '} + + {t('注册')} + + +
+ )} @@ -387,7 +399,6 @@ const LoginForm = () => { placeholder={t('请输入您的用户名或邮箱地址')} name="username" size="large" - className="!rounded-md" onChange={(value) => handleChange('username', value)} prefix={} /> @@ -399,7 +410,6 @@ const LoginForm = () => { name="password" mode="password" size="large" - className="!rounded-md" onChange={(value) => handleChange('password', value)} prefix={} /> @@ -451,9 +461,19 @@ const LoginForm = () => { )} -
- {t('没有账户?')} {t('注册')} -
+ {!status.self_use_mode_enabled && ( +
+ + {t('没有账户?')}{' '} + + {t('注册')} + + +
+ )} @@ -499,8 +519,11 @@ const LoginForm = () => { }; return ( -
-
+
+ {/* 背景模糊晕染球 */} +
+
+
{showEmailLogin || !(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth) ? renderEmailLoginForm() : renderOAuthOptions()} diff --git a/web/src/components/auth/PasswordResetConfirm.js b/web/src/components/auth/PasswordResetConfirm.js index e2d9a9ad..5fbd1fc5 100644 --- a/web/src/components/auth/PasswordResetConfirm.js +++ b/web/src/components/auth/PasswordResetConfirm.js @@ -78,8 +78,11 @@ const PasswordResetConfirm = () => { } return ( -
-
+
+ {/* 背景模糊晕染球 */} +
+
+
@@ -110,7 +113,6 @@ const PasswordResetConfirm = () => { label={t('邮箱')} name="email" size="large" - className="!rounded-md" disabled={true} prefix={} placeholder={email ? '' : t('等待获取邮箱信息...')} @@ -122,7 +124,6 @@ const PasswordResetConfirm = () => { label={t('新密码')} name="newPassword" size="large" - className="!rounded-md" disabled={true} prefix={} suffix={ diff --git a/web/src/components/auth/PasswordResetForm.js b/web/src/components/auth/PasswordResetForm.js index 29c3d477..033989e0 100644 --- a/web/src/components/auth/PasswordResetForm.js +++ b/web/src/components/auth/PasswordResetForm.js @@ -78,8 +78,11 @@ const PasswordResetForm = () => { } return ( -
-
+
+ {/* 背景模糊晕染球 */} +
+
+
@@ -99,7 +102,6 @@ const PasswordResetForm = () => { placeholder={t('请输入您的邮箱地址')} name="email" size="large" - className="!rounded-md" value={email} onChange={handleChange} prefix={} diff --git a/web/src/components/auth/RegisterForm.js b/web/src/components/auth/RegisterForm.js index 0d9c8982..9d213a60 100644 --- a/web/src/components/auth/RegisterForm.js +++ b/web/src/components/auth/RegisterForm.js @@ -35,6 +35,7 @@ import { UserContext } from '../../context/User/index.js'; import { useTranslation } from 'react-i18next'; const RegisterForm = () => { + let navigate = useNavigate(); const { t } = useTranslation(); const [inputs, setInputs] = useState({ username: '', @@ -45,15 +46,12 @@ const RegisterForm = () => { wechat_verification_code: '', }); const { username, password, password2 } = inputs; - const [showEmailVerification, setShowEmailVerification] = useState(false); const [userState, userDispatch] = useContext(UserContext); const [turnstileEnabled, setTurnstileEnabled] = useState(false); const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); const [turnstileToken, setTurnstileToken] = useState(''); - const [loading, setLoading] = useState(false); const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); const [showEmailRegister, setShowEmailRegister] = useState(false); - const [status, setStatus] = useState({}); const [wechatLoading, setWechatLoading] = useState(false); const [githubLoading, setGithubLoading] = useState(false); const [oidcLoading, setOidcLoading] = useState(false); @@ -63,7 +61,6 @@ const RegisterForm = () => { const [verificationCodeLoading, setVerificationCodeLoading] = useState(false); const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); - let navigate = useNavigate(); const logo = getLogo(); const systemName = getSystemName(); @@ -73,18 +70,22 @@ const RegisterForm = () => { localStorage.setItem('aff', affCode); } + const [status] = useState(() => { + const savedStatus = localStorage.getItem('status'); + return savedStatus ? JSON.parse(savedStatus) : {}; + }); + + const [showEmailVerification, setShowEmailVerification] = useState(() => { + return status.email_verification ?? false; + }); + useEffect(() => { - let status = localStorage.getItem('status'); - if (status) { - status = JSON.parse(status); - setStatus(status); - setShowEmailVerification(status.email_verification); - if (status.turnstile_check) { - setTurnstileEnabled(true); - setTurnstileSiteKey(status.turnstile_site_key); - } + setShowEmailVerification(status.email_verification); + if (status.turnstile_check) { + setTurnstileEnabled(true); + setTurnstileSiteKey(status.turnstile_site_key); } - }, []); + }, [status]); const onWeChatLoginClicked = () => { setWechatLoading(true); @@ -393,7 +394,6 @@ const RegisterForm = () => { placeholder={t('请输入用户名')} name="username" size="large" - className="!rounded-md" onChange={(value) => handleChange('username', value)} prefix={} /> @@ -405,7 +405,6 @@ const RegisterForm = () => { name="password" mode="password" size="large" - className="!rounded-md" onChange={(value) => handleChange('password', value)} prefix={} /> @@ -417,7 +416,6 @@ const RegisterForm = () => { name="password2" mode="password" size="large" - className="!rounded-md" onChange={(value) => handleChange('password2', value)} prefix={} /> @@ -431,7 +429,6 @@ const RegisterForm = () => { name="email" type="email" size="large" - className="!rounded-md" onChange={(value) => handleChange('email', value)} prefix={} suffix={ @@ -439,7 +436,6 @@ const RegisterForm = () => { onClick={sendVerificationCode} loading={verificationCodeLoading} size="small" - className="!rounded-md mr-2" > {t('获取验证码')} @@ -451,7 +447,6 @@ const RegisterForm = () => { placeholder={t('输入验证码')} name="verification_code" size="large" - className="!rounded-md" onChange={(value) => handleChange('verification_code', value)} prefix={} /> @@ -541,8 +536,11 @@ const RegisterForm = () => { }; return ( -
-
+
+ {/* 背景模糊晕染球 */} +
+
+
{showEmailRegister || !(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth) ? renderEmailRegisterForm() : renderOAuthOptions()} diff --git a/web/src/components/layout/HeaderBar.js b/web/src/components/layout/HeaderBar.js index 6317c576..b7425645 100644 --- a/web/src/components/layout/HeaderBar.js +++ b/web/src/components/layout/HeaderBar.js @@ -28,6 +28,7 @@ import { Tag, Typography, Skeleton, + Badge, } from '@douyinfe/semi-ui'; import { StatusContext } from '../../context/Status/index.js'; import { useStyle, styleActions } from '../../context/Style/index.js'; @@ -43,6 +44,7 @@ const HeaderBar = () => { const [mobileMenuOpen, setMobileMenuOpen] = useState(false); const location = useLocation(); const [noticeVisible, setNoticeVisible] = useState(false); + const [unreadCount, setUnreadCount] = useState(0); const systemName = getSystemName(); const logo = getLogo(); @@ -53,9 +55,44 @@ const HeaderBar = () => { const docsLink = statusState?.status?.docs_link || ''; const isDemoSiteMode = statusState?.status?.demo_site_enabled || false; + const isConsoleRoute = location.pathname.startsWith('/console'); + const theme = useTheme(); const setTheme = useSetTheme(); + const announcements = statusState?.status?.announcements || []; + + const getAnnouncementKey = (a) => `${a?.publishDate || ''}-${(a?.content || '').slice(0, 30)}`; + + const calculateUnreadCount = () => { + if (!announcements.length) return 0; + let readKeys = []; + try { + readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || []; + } catch (_) { + readKeys = []; + } + const readSet = new Set(readKeys); + return announcements.filter((a) => !readSet.has(getAnnouncementKey(a))).length; + }; + + const getUnreadKeys = () => { + if (!announcements.length) return []; + let readKeys = []; + try { + readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || []; + } catch (_) { + readKeys = []; + } + const readSet = new Set(readKeys); + return announcements.filter((a) => !readSet.has(getAnnouncementKey(a))).map(getAnnouncementKey); + }; + + useEffect(() => { + setUnreadCount(calculateUnreadCount()); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [announcements]); + const mainNavLinks = [ { text: t('首页'), @@ -106,6 +143,25 @@ const HeaderBar = () => { }, 3000); }; + const handleNoticeOpen = () => { + setNoticeVisible(true); + }; + + const handleNoticeClose = () => { + setNoticeVisible(false); + if (announcements.length) { + let readKeys = []; + try { + readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || []; + } catch (_) { + readKeys = []; + } + const mergedKeys = Array.from(new Set([...readKeys, ...announcements.map(getAnnouncementKey)])); + localStorage.setItem('notice_read_keys', JSON.stringify(mergedKeys)); + } + setUnreadCount(0); + }; + useEffect(() => { if (theme === 'dark') { document.body.setAttribute('theme-mode', 'dark'); @@ -353,15 +409,14 @@ const HeaderBar = () => { } }; - // 检查当前路由是否以/console开头 - const isConsoleRoute = location.pathname.startsWith('/console'); - return (
setNoticeVisible(false)} + onClose={handleNoticeClose} isMobile={styleState.isMobile} + defaultTab={unreadCount > 0 ? 'system' : 'inApp'} + unreadKeys={getUnreadKeys()} />
@@ -462,14 +517,27 @@ const HeaderBar = () => { )} - - + +
)} size={isMobile ? 'full-width' : 'large'} > - {renderContent()} + {renderBody()} ); }; diff --git a/web/src/components/layout/PageLayout.js b/web/src/components/layout/PageLayout.js index e25901ef..17d16fc0 100644 --- a/web/src/components/layout/PageLayout.js +++ b/web/src/components/layout/PageLayout.js @@ -11,7 +11,7 @@ import { API, getLogo, getSystemName, showError, setStatusData } from '../../hel import { UserContext } from '../../context/User/index.js'; import { StatusContext } from '../../context/Status/index.js'; import { useLocation } from 'react-router-dom'; -const { Sider, Content, Header, Footer } = Layout; +const { Sider, Content, Header } = Layout; const PageLayout = () => { const [userState, userDispatch] = useContext(UserContext); @@ -94,8 +94,6 @@ const PageLayout = () => {
{ + const [searchText, setSearchText] = useState(''); + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + const [filteredData, setFilteredData] = useState([]); + + useImperativeHandle(ref, () => ({ + resetPagination: () => { + setCurrentPage(1); + setSearchText(''); + }, + })); + + useEffect(() => { + if (!allChannels) return; + + const searchLower = searchText.trim().toLowerCase(); + const matched = searchLower + ? allChannels.filter((item) => { + const name = (item.label || '').toLowerCase(); + const baseUrl = (item._originalData?.base_url || '').toLowerCase(); + return name.includes(searchLower) || baseUrl.includes(searchLower); + }) + : allChannels; + + setFilteredData(matched); + }, [allChannels, searchText]); + + const total = filteredData.length; + + const paginatedData = filteredData.slice( + (currentPage - 1) * pageSize, + currentPage * pageSize, + ); + + const updateEndpoint = (channelId, endpoint) => { + if (typeof updateChannelEndpoint === 'function') { + updateChannelEndpoint(channelId, endpoint); + } + }; + + const renderEndpointCell = (text, record) => { + const channelId = record.key || record.value; + const currentEndpoint = channelEndpoints[channelId] || ''; + + const getEndpointType = (ep) => { + if (ep === '/api/ratio_config') return 'ratio_config'; + if (ep === '/api/pricing') return 'pricing'; + return 'custom'; + }; + + const currentType = getEndpointType(currentEndpoint); + + const handleTypeChange = (val) => { + if (val === 'ratio_config') { + updateEndpoint(channelId, '/api/ratio_config'); + } else if (val === 'pricing') { + updateEndpoint(channelId, '/api/pricing'); + } else { + if (currentType !== 'custom') { + updateEndpoint(channelId, ''); + } + } + }; + + return ( +
+ updateEndpoint(channelId, val)} + placeholder="/your/endpoint" + style={{ width: 160, fontSize: 12 }} + /> + )} +
+ ); + }; + + const renderStatusCell = (status) => { + switch (status) { + case 1: + return ( + }> + {t('已启用')} + + ); + case 2: + return ( + }> + {t('已禁用')} + + ); + case 3: + return ( + }> + {t('自动禁用')} + + ); + default: + return ( + }> + {t('未知状态')} + + ); + } + }; + + const renderNameCell = (text) => ( + + ); + + const renderBaseUrlCell = (text) => ( + + ); + + const columns = [ + { + title: t('名称'), + dataIndex: 'label', + render: renderNameCell, + }, + { + title: t('源地址'), + dataIndex: '_originalData.base_url', + render: (_, record) => renderBaseUrlCell(record._originalData?.base_url || ''), + }, + { + title: t('状态'), + dataIndex: '_originalData.status', + render: (_, record) => renderStatusCell(record._originalData?.status || 0), + }, + { + title: t('同步接口'), + dataIndex: 'endpoint', + fixed: 'right', + render: renderEndpointCell, + }, + ]; + + const rowSelection = { + selectedRowKeys: selectedChannelIds, + onChange: (keys) => setSelectedChannelIds(keys), + }; + + return ( + {t('选择同步渠道')}} + size={isMobile() ? 'full-width' : 'large'} + keepDOM + lazyRender={false} + > + + } + placeholder={t('搜索渠道名称或地址')} + value={searchText} + onChange={setSearchText} + showClear + /> + + t('第 {{start}} - {{end}} 条,共 {{total}} 条', { + start: page.currentStart, + end: page.currentEnd, + total: total, + }), + onChange: (page, size) => { + setCurrentPage(page); + setPageSize(size); + }, + onShowSizeChange: (curr, size) => { + setCurrentPage(1); + setPageSize(size); + }, + }} + size="small" + /> + + + ); +}); + +export default ChannelSelectorModal; \ No newline at end of file diff --git a/web/src/components/settings/ChatsSetting.js b/web/src/components/settings/ChatsSetting.js new file mode 100644 index 00000000..6330808d --- /dev/null +++ b/web/src/components/settings/ChatsSetting.js @@ -0,0 +1,63 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin } from '@douyinfe/semi-ui'; +import SettingsChats from '../../pages/Setting/Chat/SettingsChats.js'; +import { API, showError } from '../../helpers'; + +const ChatsSetting = () => { + let [inputs, setInputs] = useState({ + /* 聊天设置 */ + Chats: '[]', + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if ( + item.key.endsWith('Enabled') || + ['DefaultCollapseSidebar'].includes(item.key) + ) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } + }; + + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + {/* 聊天设置 */} + + + + + + ); +}; + +export default ChatsSetting; \ No newline at end of file diff --git a/web/src/components/settings/DashboardSetting.js b/web/src/components/settings/DashboardSetting.js index bf4a26a3..4fa1ad10 100644 --- a/web/src/components/settings/DashboardSetting.js +++ b/web/src/components/settings/DashboardSetting.js @@ -5,6 +5,7 @@ import SettingsAPIInfo from '../../pages/Setting/Dashboard/SettingsAPIInfo.js'; import SettingsAnnouncements from '../../pages/Setting/Dashboard/SettingsAnnouncements.js'; import SettingsFAQ from '../../pages/Setting/Dashboard/SettingsFAQ.js'; import SettingsUptimeKuma from '../../pages/Setting/Dashboard/SettingsUptimeKuma.js'; +import SettingsDataDashboard from '../../pages/Setting/Dashboard/SettingsDataDashboard.js'; const DashboardSetting = () => { let [inputs, setInputs] = useState({ @@ -23,6 +24,11 @@ const DashboardSetting = () => { FAQ: '', UptimeKumaUrl: '', UptimeKumaSlug: '', + + /* 数据看板 */ + DataExportEnabled: false, + DataExportDefaultTime: 'hour', + DataExportInterval: 5, }); let [loading, setLoading] = useState(false); @@ -37,6 +43,10 @@ const DashboardSetting = () => { if (item.key in inputs) { newInputs[item.key] = item.value; } + if (item.key.endsWith('Enabled') && + (item.key === 'DataExportEnabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } }); setInputs(newInputs); } else { @@ -106,9 +116,9 @@ const DashboardSetting = () => {

- {/* API信息管理 */} + {/* 数据看板设置 */} - + {/* 系统公告管理 */} @@ -116,6 +126,11 @@ const DashboardSetting = () => { + {/* API信息管理 */} + + + + {/* 常见问答管理 */} diff --git a/web/src/components/settings/DrawingSetting.js b/web/src/components/settings/DrawingSetting.js new file mode 100644 index 00000000..d2cdce1e --- /dev/null +++ b/web/src/components/settings/DrawingSetting.js @@ -0,0 +1,65 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin } from '@douyinfe/semi-ui'; +import SettingsDrawing from '../../pages/Setting/Drawing/SettingsDrawing.js'; +import { API, showError } from '../../helpers'; + +const DrawingSetting = () => { + let [inputs, setInputs] = useState({ + /* 绘图设置 */ + DrawingEnabled: false, + MjNotifyEnabled: false, + MjAccountFilterEnabled: false, + MjForwardUrlEnabled: false, + MjModeClearEnabled: false, + MjActionCheckSuccessEnabled: false, + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + + setInputs(newInputs); + } else { + showError(message); + } + }; + + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + {/* 绘图设置 */} + + + + + + ); +}; + +export default DrawingSetting; \ No newline at end of file diff --git a/web/src/components/settings/OperationSetting.js b/web/src/components/settings/OperationSetting.js index 55e328a3..75a5c81a 100644 --- a/web/src/components/settings/OperationSetting.js +++ b/web/src/components/settings/OperationSetting.js @@ -1,66 +1,44 @@ import React, { useEffect, useState } from 'react'; -import { Card, Spin, Tabs } from '@douyinfe/semi-ui'; +import { Card, Spin } from '@douyinfe/semi-ui'; import SettingsGeneral from '../../pages/Setting/Operation/SettingsGeneral.js'; -import SettingsDrawing from '../../pages/Setting/Operation/SettingsDrawing.js'; import SettingsSensitiveWords from '../../pages/Setting/Operation/SettingsSensitiveWords.js'; import SettingsLog from '../../pages/Setting/Operation/SettingsLog.js'; -import SettingsDataDashboard from '../../pages/Setting/Operation/SettingsDataDashboard.js'; import SettingsMonitoring from '../../pages/Setting/Operation/SettingsMonitoring.js'; import SettingsCreditLimit from '../../pages/Setting/Operation/SettingsCreditLimit.js'; -import ModelSettingsVisualEditor from '../../pages/Setting/Operation/ModelSettingsVisualEditor.js'; -import GroupRatioSettings from '../../pages/Setting/Operation/GroupRatioSettings.js'; -import ModelRatioSettings from '../../pages/Setting/Operation/ModelRatioSettings.js'; - -import { API, showError, showSuccess } from '../../helpers'; -import SettingsChats from '../../pages/Setting/Operation/SettingsChats.js'; -import { useTranslation } from 'react-i18next'; -import ModelRatioNotSetEditor from '../../pages/Setting/Operation/ModelRationNotSetEditor.js'; +import { API, showError } from '../../helpers'; const OperationSetting = () => { - const { t } = useTranslation(); let [inputs, setInputs] = useState({ + /* 额度相关 */ QuotaForNewUser: 0, + PreConsumedQuota: 0, QuotaForInviter: 0, QuotaForInvitee: 0, - QuotaRemindThreshold: 0, - PreConsumedQuota: 0, - StreamCacheQueueLength: 0, - ModelRatio: '', - CacheRatio: '', - CompletionRatio: '', - ModelPrice: '', - GroupRatio: '', - GroupGroupRatio: '', - UserUsableGroups: '', + + /* 通用设置 */ TopUpLink: '', 'general_setting.docs_link': '', - // ChatLink2: '', // 添加的新状态变量 QuotaPerUnit: 0, - AutomaticDisableChannelEnabled: false, - AutomaticEnableChannelEnabled: false, - ChannelDisableThreshold: 0, - LogConsumeEnabled: false, + RetryTimes: 0, DisplayInCurrencyEnabled: false, DisplayTokenStatEnabled: false, - CheckSensitiveEnabled: false, - CheckSensitiveOnPromptEnabled: false, - CheckSensitiveOnCompletionEnabled: '', - StopOnSensitiveEnabled: '', - SensitiveWords: '', - MjNotifyEnabled: false, - MjAccountFilterEnabled: false, - MjModeClearEnabled: false, - MjForwardUrlEnabled: false, - MjActionCheckSuccessEnabled: false, - DrawingEnabled: false, - DataExportEnabled: false, - DataExportDefaultTime: 'hour', - DataExportInterval: 5, - DefaultCollapseSidebar: false, // 默认折叠侧边栏 - RetryTimes: 0, - Chats: '[]', + DefaultCollapseSidebar: false, DemoSiteEnabled: false, SelfUseModeEnabled: false, + + /* 敏感词设置 */ + CheckSensitiveEnabled: false, + CheckSensitiveOnPromptEnabled: false, + SensitiveWords: '', + + /* 日志设置 */ + LogConsumeEnabled: false, + + /* 监控设置 */ + ChannelDisableThreshold: 0, + QuotaRemindThreshold: 0, + AutomaticDisableChannelEnabled: false, + AutomaticEnableChannelEnabled: false, AutomaticDisableKeywords: '', }); @@ -72,17 +50,6 @@ const OperationSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - if ( - item.key === 'ModelRatio' || - item.key === 'GroupRatio' || - item.key === 'GroupGroupRatio' || - item.key === 'UserUsableGroups' || - item.key === 'CompletionRatio' || - item.key === 'ModelPrice' || - item.key === 'CacheRatio' - ) { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } if ( item.key.endsWith('Enabled') || ['DefaultCollapseSidebar'].includes(item.key) @@ -121,10 +88,6 @@ const OperationSetting = () => { - {/* 绘图设置 */} - - - {/* 屏蔽词过滤设置 */} @@ -133,10 +96,6 @@ const OperationSetting = () => { - {/* 数据看板 */} - - - {/* 监控设置 */} @@ -145,28 +104,6 @@ const OperationSetting = () => { - {/* 聊天设置 */} - - - - {/* 分组倍率设置 */} - - - - {/* 合并模型倍率设置和可视化倍率设置 */} - - - - - - - - - - - - - ); diff --git a/web/src/components/settings/PaymentSetting.js b/web/src/components/settings/PaymentSetting.js new file mode 100644 index 00000000..91a40a2b --- /dev/null +++ b/web/src/components/settings/PaymentSetting.js @@ -0,0 +1,88 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin } from '@douyinfe/semi-ui'; +import SettingsGeneralPayment from '../../pages/Setting/Payment/SettingsGeneralPayment.js'; +import SettingsPaymentGateway from '../../pages/Setting/Payment/SettingsPaymentGateway.js'; +import { API, showError } from '../../helpers'; +import { useTranslation } from 'react-i18next'; + +const PaymentSetting = () => { + const { t } = useTranslation(); + let [inputs, setInputs] = useState({ + ServerAddress: '', + PayAddress: '', + EpayId: '', + EpayKey: '', + Price: 7.3, + MinTopUp: 1, + TopupGroupRatio: '', + CustomCallbackAddress: '', + PayMethods: '', + }); + + let [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + switch (item.key) { + case 'TopupGroupRatio': + try { + newInputs[item.key] = JSON.stringify(JSON.parse(item.value), null, 2); + } catch (error) { + console.error('解析TopupGroupRatio出错:', error); + newInputs[item.key] = item.value; + } + break; + case 'Price': + case 'MinTopUp': + newInputs[item.key] = parseFloat(item.value); + break; + default: + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + break; + } + }); + + setInputs(newInputs); + } else { + showError(t(message)); + } + }; + + async function onRefresh() { + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError(t('刷新失败')); + } finally { + setLoading(false); + } + } + + useEffect(() => { + onRefresh(); + }, []); + + return ( + <> + + + + + + + + + + ); +}; + +export default PaymentSetting; \ No newline at end of file diff --git a/web/src/components/settings/PersonalSetting.js b/web/src/components/settings/PersonalSetting.js index 36eb4e4d..7e2b85fd 100644 --- a/web/src/components/settings/PersonalSetting.js +++ b/web/src/components/settings/PersonalSetting.js @@ -379,257 +379,268 @@ const PersonalSetting = () => { }; return ( -
- - +
+
+
+ {/* 主卡片容器 */} + + {/* 顶部用户信息区域 */} + + {/* 装饰性背景元素 */} +
+
+
+
+
-
-
- {/* 主卡片容器 */} - - {/* 顶部用户信息区域 */} - - {/* 装饰性背景元素 */} -
-
-
-
-
- -
-
-
- - {getAvatarText()} - -
-
- {getUsername()} -
-
- {isRoot() ? ( - - {t('超级管理员')} - - ) : isAdmin() ? ( - - {t('管理员')} - - ) : ( - - {t('普通用户')} - - )} - - ID: {userState?.user?.id} - -
-
-
-
- -
-
- -
-
- {t('当前余额')} -
-
- {renderQuota(userState?.user?.quota)} -
-
- -
-
-
-
- {t('历史消耗')} -
-
- {renderQuota(userState?.user?.used_quota)} -
-
-
-
- {t('请求次数')} -
-
- {userState.user?.request_count || 0} -
-
-
-
- {t('用户分组')} -
-
- {userState?.user?.group || t('默认')} -
-
-
-
- -
-
-
- - {/* 主内容区域 - 使用Tabs组织不同功能模块 */} -
- - {/* 可用模型Tab */} - - - {t('可用模型')} -
- } - itemKey='models' +
+
+
+ -
- {/* 可用模型部分 */} -
-
-
- -
-
- {t('模型列表')} -
{t('点击模型名称可复制')}
+ {getAvatarText()} + +
+
+ {getUsername()} +
+
+ {isRoot() ? ( + + {t('超级管理员')} + + ) : isAdmin() ? ( + + {t('管理员')} + + ) : ( + + {t('普通用户')} + + )} + + ID: {userState?.user?.id} + +
+
+
+
+ +
+
+ +
+
+ {t('当前余额')} +
+
+ {renderQuota(userState?.user?.quota)} +
+
+ +
+
+
+
+ {t('历史消耗')} +
+
+ {renderQuota(userState?.user?.used_quota)} +
+
+
+
+ {t('请求次数')} +
+
+ {userState.user?.request_count || 0} +
+
+
+
+ {t('用户分组')} +
+
+ {userState?.user?.group || t('默认')} +
+
+
+
+ +
+
+ + + {/* 主内容区域 - 使用Tabs组织不同功能模块 */} +
+ + {/* 可用模型Tab */} + + + {t('可用模型')} +
+ } + itemKey='models' + > +
+ {/* 可用模型部分 */} +
+
+
+ +
+
+ {t('模型列表')} +
{t('点击模型名称可复制')}
+
+
+ + {modelsLoading ? ( + // 骨架屏加载状态 - 模拟实际加载后的布局 +
+ {/* 模拟分类标签 */} +
+
+ {Array.from({ length: 8 }).map((_, index) => ( + + ))}
- {modelsLoading ? ( - // 骨架屏加载状态 - 模拟实际加载后的布局 -
- {/* 模拟分类标签 */} -
-
- {Array.from({ length: 8 }).map((_, index) => ( - - ))} -
-
- - {/* 模拟模型标签列表 */} -
- {Array.from({ length: 20 }).map((_, index) => ( - - ))} -
-
- ) : models.length === 0 ? ( -
- } - darkModeImage={} - description={t('没有可用模型')} - style={{ padding: '24px 0' }} + {/* 模拟模型标签列表 */} +
+ {Array.from({ length: 20 }).map((_, index) => ( + -
- ) : ( - <> - {/* 模型分类标签页 */} -
- setActiveModelCategory(key)} - className="mt-2" - > - {Object.entries(getModelCategories(t)).map(([key, category]) => { - // 计算该分类下的模型数量 - const modelCount = key === 'all' - ? models.length - : models.filter(model => category.filter({ model_name: model })).length; + ))} +
+
+ ) : models.length === 0 ? ( +
+ } + darkModeImage={} + description={t('没有可用模型')} + style={{ padding: '24px 0' }} + /> +
+ ) : ( + <> + {/* 模型分类标签页 */} +
+ setActiveModelCategory(key)} + className="mt-2" + > + {Object.entries(getModelCategories(t)).map(([key, category]) => { + // 计算该分类下的模型数量 + const modelCount = key === 'all' + ? models.length + : models.filter(model => category.filter({ model_name: model })).length; - if (modelCount === 0 && key !== 'all') return null; + if (modelCount === 0 && key !== 'all') return null; - return ( - - {category.icon && {category.icon}} - {category.label} - - {modelCount} - - - } - itemKey={key} - key={key} - /> - ); - })} - -
+ return ( + + {category.icon && {category.icon}} + {category.label} + + {modelCount} + + + } + itemKey={key} + key={key} + /> + ); + })} + +
-
- {(() => { - // 根据当前选中的分类过滤模型 - const categories = getModelCategories(t); - const filteredModels = activeModelCategory === 'all' - ? models - : models.filter(model => categories[activeModelCategory].filter({ model_name: model })); +
+ {(() => { + // 根据当前选中的分类过滤模型 + const categories = getModelCategories(t); + const filteredModels = activeModelCategory === 'all' + ? models + : models.filter(model => categories[activeModelCategory].filter({ model_name: model })); - // 如果过滤后没有模型,显示空状态 - if (filteredModels.length === 0) { - return ( - } - darkModeImage={} - description={t('该分类下没有可用模型')} - style={{ padding: '16px 0' }} - /> - ); - } + // 如果过滤后没有模型,显示空状态 + if (filteredModels.length === 0) { + return ( + } + darkModeImage={} + description={t('该分类下没有可用模型')} + style={{ padding: '16px 0' }} + /> + ); + } - if (filteredModels.length <= MODELS_DISPLAY_COUNT) { - return ( + if (filteredModels.length <= MODELS_DISPLAY_COUNT) { + return ( + + {filteredModels.map((model) => ( + renderModelTag(model, { + size: 'large', + shape: 'circle', + onClick: () => copyText(model), + }) + ))} + + ); + } else { + return ( + <> + {filteredModels.map((model) => ( renderModelTag(model, { @@ -638,527 +649,513 @@ const PersonalSetting = () => { onClick: () => copyText(model), }) ))} + setIsModelsExpanded(false)} + icon={} + > + {t('收起')} + + + + {!isModelsExpanded && ( + + {filteredModels + .slice(0, MODELS_DISPLAY_COUNT) + .map((model) => ( + renderModelTag(model, { + size: 'large', + shape: 'circle', + onClick: () => copyText(model), + }) + ))} + setIsModelsExpanded(true)} + icon={} + > + {t('更多')} {filteredModels.length - MODELS_DISPLAY_COUNT} {t('个模型')} + - ); - } else { - return ( - <> - - - {filteredModels.map((model) => ( - renderModelTag(model, { - size: 'large', - shape: 'circle', - onClick: () => copyText(model), - }) - ))} - setIsModelsExpanded(false)} - icon={} - > - {t('收起')} - - - - {!isModelsExpanded && ( - - {filteredModels - .slice(0, MODELS_DISPLAY_COUNT) - .map((model) => ( - renderModelTag(model, { - size: 'large', - shape: 'circle', - onClick: () => copyText(model), - }) - ))} - setIsModelsExpanded(true)} - icon={} - > - {t('更多')} {filteredModels.length - MODELS_DISPLAY_COUNT} {t('个模型')} - - - )} - - ); - } - })()} -
- - )} -
-
- - - {/* 账户绑定Tab */} - - - {t('账户绑定')} -
- } - itemKey='account' - > -
-
- {/* 邮箱绑定 */} - -
-
-
- -
-
-
{t('邮箱')}
-
- {userState.user && userState.user.email !== '' - ? userState.user.email - : t('未绑定')} -
-
-
- -
-
- - {/* 微信绑定 */} - -
-
-
- -
-
-
{t('微信')}
-
- {userState.user && userState.user.wechat_id !== '' - ? t('已绑定') - : t('未绑定')} -
-
-
- -
-
- - {/* GitHub绑定 */} - -
-
-
- -
-
-
{t('GitHub')}
-
- {userState.user && userState.user.github_id !== '' - ? userState.user.github_id - : t('未绑定')} -
-
-
- -
-
- - {/* OIDC绑定 */} - -
-
-
- -
-
-
{t('OIDC')}
-
- {userState.user && userState.user.oidc_id !== '' - ? userState.user.oidc_id - : t('未绑定')} -
-
-
- -
-
- - {/* Telegram绑定 */} - -
-
-
- -
-
-
{t('Telegram')}
-
- {userState.user && userState.user.telegram_id !== '' - ? userState.user.telegram_id - : t('未绑定')} -
-
-
-
- {status.telegram_oauth ? ( - userState.user.telegram_id !== '' ? ( - - ) : ( -
- -
- ) - ) : ( - - )} -
-
-
- - {/* LinuxDO绑定 */} - -
-
-
- -
-
-
{t('LinuxDO')}
-
- {userState.user && userState.user.linux_do_id !== '' - ? userState.user.linux_do_id - : t('未绑定')} -
-
-
- -
-
-
-
- - - {/* 安全设置Tab */} - - - {t('安全设置')} -
- } - itemKey='security' - > -
-
- - {/* 系统访问令牌 */} - -
-
-
- -
-
- - {t('系统访问令牌')} - - - {t('用于API调用的身份验证令牌,请妥善保管')} - - {systemToken && ( -
- } - /> -
)} -
-
- -
-
+ + ); + } + })()} +
+ + )} +
+
+ - {/* 密码管理 */} - -
-
-
- -
-
- - {t('密码管理')} - - - {t('定期更改密码可以提高账户安全性')} - -
-
- + {/* 账户绑定Tab */} + + + {t('账户绑定')} +
+ } + itemKey='account' + > +
+
+ {/* 邮箱绑定 */} + +
+
+
+ +
+
+
{t('邮箱')}
+
+ {userState.user && userState.user.email !== '' + ? userState.user.email + : t('未绑定')}
- - - {/* 危险区域 */} - -
-
-
- -
-
- - {t('删除账户')} - - - {t('此操作不可逆,所有数据将被永久删除')} - -
-
- -
-
- -
-
- - - {/* 通知设置Tab */} - - - {t('其他设置')} -
- } - itemKey='notification' - > -
- - +
+ +
+ + + {/* 微信绑定 */} + +
+
+
+ +
+
+
{t('微信')}
+
+ {userState.user && userState.user.wechat_id !== '' + ? t('已绑定') + : t('未绑定')}
+
+
+ +
+
- {/* Webhook设置 */} - {notificationSettings.warningType === 'webhook' && ( -
-
- {t('Webhook地址')} - - handleNotificationSettingChange('webhookUrl', val) - } - placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')} - size="large" - className="!rounded-lg" - prefix={} - /> -
- {t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} -
-
+ {/* GitHub绑定 */} + +
+
+
+ +
+
+
{t('GitHub')}
+
+ {userState.user && userState.user.github_id !== '' + ? userState.user.github_id + : t('未绑定')} +
+
+
+ +
+
-
- {t('接口凭证(可选)')} + {/* OIDC绑定 */} + +
+
+
+ +
+
+
{t('OIDC')}
+
+ {userState.user && userState.user.oidc_id !== '' + ? userState.user.oidc_id + : t('未绑定')} +
+
+
+ +
+
+ + {/* Telegram绑定 */} + +
+
+
+ +
+
+
{t('Telegram')}
+
+ {userState.user && userState.user.telegram_id !== '' + ? userState.user.telegram_id + : t('未绑定')} +
+
+
+
+ {status.telegram_oauth ? ( + userState.user.telegram_id !== '' ? ( + + ) : ( +
+ +
+ ) + ) : ( + + )} +
+
+
+ + {/* LinuxDO绑定 */} + +
+
+
+ +
+
+
{t('LinuxDO')}
+
+ {userState.user && userState.user.linux_do_id !== '' + ? userState.user.linux_do_id + : t('未绑定')} +
+
+
+ +
+
+
+
+ + + {/* 安全设置Tab */} + + + {t('安全设置')} +
+ } + itemKey='security' + > +
+
+ + {/* 系统访问令牌 */} + +
+
+
+ +
+
+ + {t('系统访问令牌')} + + + {t('用于API调用的身份验证令牌,请妥善保管')} + + {systemToken && ( +
- handleNotificationSettingChange('webhookSecret', val) - } - placeholder={t('请输入密钥')} + readonly + value={systemToken} + onClick={handleSystemTokenClick} size="large" className="!rounded-lg" prefix={} /> -
- {t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')} -
+ )} +
+
+ +
+
-
-
setShowWebhookDocs(!showWebhookDocs)}> -
- - - {t('Webhook请求结构')} - -
- {showWebhookDocs ? : } -
- -
-                                        {`{
+                        {/* 密码管理 */}
+                        
+                          
+
+
+ +
+
+ + {t('密码管理')} + + + {t('定期更改密码可以提高账户安全性')} + +
+
+ +
+
+ + {/* 危险区域 */} + +
+
+
+ +
+
+ + {t('删除账户')} + + + {t('此操作不可逆,所有数据将被永久删除')} + +
+
+ +
+
+ +
+
+ + + {/* 通知设置Tab */} + + + {t('其他设置')} +
+ } + itemKey='notification' + > +
+ + +
+ {/* 通知方式选择 */} +
+ {t('通知方式')} + + handleNotificationSettingChange('warningType', value) + } + type="pureCard" + > + +
+ +
+
{t('邮件通知')}
+
{t('通过邮件接收通知')}
+
+
+
+ +
+ +
+
{t('Webhook通知')}
+
{t('通过HTTP请求接收通知')}
+
+
+
+
+
+ + {/* Webhook设置 */} + {notificationSettings.warningType === 'webhook' && ( +
+
+ {t('Webhook地址')} + + handleNotificationSettingChange('webhookUrl', val) + } + placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')} + size="large" + className="!rounded-lg" + prefix={} + /> +
+ {t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')} +
+
+ +
+ {t('接口凭证(可选)')} + + handleNotificationSettingChange('webhookSecret', val) + } + placeholder={t('请输入密钥')} + size="large" + className="!rounded-lg" + prefix={} + /> +
+ {t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')} +
+
+ +
+
setShowWebhookDocs(!showWebhookDocs)}> +
+ + + {t('Webhook请求结构')} + +
+ {showWebhookDocs ? : } +
+ +
+                                    {`{
   "type": "quota_exceed",      // 通知类型
   "title": "标题",             // 通知标题
   "content": "通知内容",       // 通知内容,支持 {{value}} 变量占位符
@@ -1174,158 +1171,156 @@ const PersonalSetting = () => {
   "values": ["$0.99"],
   "timestamp": 1739950503
 }`}
-                                      
-
-
-
- )} - - {/* 邮件设置 */} - {notificationSettings.warningType === 'email' && ( -
- {t('通知邮箱')} - - handleNotificationSettingChange('notificationEmail', val) - } - placeholder={t('留空则使用账号绑定的邮箱')} - size="large" - className="!rounded-lg" - prefix={} - /> -
- {t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')} -
-
- )} - - {/* 预警阈值 */} -
- - {t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)} - - - handleNotificationSettingChange('warningThreshold', val) - } - size="large" - className="!rounded-lg w-full max-w-xs" - placeholder={t('请输入预警额度')} - data={[ - { value: 100000, label: '0.2$' }, - { value: 500000, label: '1$' }, - { value: 1000000, label: '5$' }, - { value: 5000000, label: '10$' }, - ]} - prefix={} - /> -
- {t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')} -
+ +
-
+ )} - -
-
- {/* 接受未设置价格模型 */} -
-
-
- -
-
-
-
- - {t('接受未设置价格模型')} - -
- {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} -
-
- - handleNotificationSettingChange( - 'acceptUnsetModelRatioModel', - e.target.checked, - ) - } - className="ml-4" - /> -
-
-
-
+ {/* 邮件设置 */} + {notificationSettings.warningType === 'email' && ( +
+ {t('通知邮箱')} + + handleNotificationSettingChange('notificationEmail', val) + } + placeholder={t('留空则使用账号绑定的邮箱')} + size="large" + className="!rounded-lg" + prefix={} + /> +
+ {t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')}
- + )} - -
-
-
-
- -
-
-
-
- - {t('记录请求与错误日志 IP')} - -
- {t('开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址')} -
-
- - handleNotificationSettingChange( - 'recordIpLog', - e.target.checked, - ) - } - className="ml-4" - /> -
-
-
-
+ {/* 预警阈值 */} +
+ + {t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)} + + + handleNotificationSettingChange('warningThreshold', val) + } + size="large" + className="!rounded-lg w-full max-w-xs" + placeholder={t('请输入预警额度')} + data={[ + { value: 100000, label: '0.2$' }, + { value: 500000, label: '1$' }, + { value: 1000000, label: '5$' }, + { value: 5000000, label: '10$' }, + ]} + prefix={} + /> +
+ {t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')}
- - - -
- +
-
-
- -
- + + + +
+
+ {/* 接受未设置价格模型 */} +
+
+
+ +
+
+
+
+ + {t('接受未设置价格模型')} + +
+ {t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')} +
+
+ + handleNotificationSettingChange( + 'acceptUnsetModelRatioModel', + e.target.checked, + ) + } + className="ml-4" + /> +
+
+
+
+
+
+
+ + +
+
+
+
+ +
+
+
+
+ + {t('记录请求与错误日志 IP')} + +
+ {t('开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址')} +
+
+ + handleNotificationSettingChange( + 'recordIpLog', + e.target.checked, + ) + } + className="ml-4" + /> +
+
+
+
+
+
+ + +
+ +
+
+
+
-
- - + +
+
{/* 邮箱绑定模态框 */} { if (success) { let newInputs = {}; data.forEach((item) => { - if (item.key === 'ModelRequestRateLimitGroup') { - item.value = JSON.stringify(JSON.parse(item.value), null, 2); - } + if (item.key === 'ModelRequestRateLimitGroup') { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } - if (item.key.endsWith('Enabled')) { - newInputs[item.key] = item.value === 'true' ? true : false; - } else { - newInputs[item.key] = item.value; + if (item.key.endsWith('Enabled')) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; } }); diff --git a/web/src/components/settings/RatioSetting.js b/web/src/components/settings/RatioSetting.js new file mode 100644 index 00000000..b0284e1d --- /dev/null +++ b/web/src/components/settings/RatioSetting.js @@ -0,0 +1,122 @@ +import React, { useEffect, useState } from 'react'; +import { Card, Spin, Tabs } from '@douyinfe/semi-ui'; +import { useTranslation } from 'react-i18next'; + +import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js'; +import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js'; +import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js'; +import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js'; +import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js'; + +import { API, showError } from '../../helpers'; + +const RatioSetting = () => { + const { t } = useTranslation(); + + let [inputs, setInputs] = useState({ + ModelPrice: '', + ModelRatio: '', + CacheRatio: '', + CompletionRatio: '', + GroupRatio: '', + GroupGroupRatio: '', + AutoGroups: '', + DefaultUseAutoGroup: false, + ExposeRatioEnabled: false, + UserUsableGroups: '', + }); + + const [loading, setLoading] = useState(false); + + const getOptions = async () => { + const res = await API.get('/api/option/'); + const { success, message, data } = res.data; + if (success) { + let newInputs = {}; + data.forEach((item) => { + if ( + item.key === 'ModelRatio' || + item.key === 'GroupRatio' || + item.key === 'GroupGroupRatio' || + item.key === 'AutoGroups' || + item.key === 'UserUsableGroups' || + item.key === 'CompletionRatio' || + item.key === 'ModelPrice' || + item.key === 'CacheRatio' + ) { + try { + item.value = JSON.stringify(JSON.parse(item.value), null, 2); + } catch (e) { + // 如果后端返回的不是合法 JSON,直接展示 + } + } + if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) { + newInputs[item.key] = item.value === 'true' ? true : false; + } else { + newInputs[item.key] = item.value; + } + }); + setInputs(newInputs); + } else { + showError(message); + } + }; + + const onRefresh = async () => { + try { + setLoading(true); + await getOptions(); + } catch (error) { + showError('刷新失败'); + } finally { + setLoading(false); + } + }; + + useEffect(() => { + onRefresh(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + return ( + + {/* 模型倍率设置以及可视化编辑器 */} + + + + + + + + + + + + + + + + + + + + + ); +}; + +export default RatioSetting; \ No newline at end of file diff --git a/web/src/components/settings/SystemSetting.js b/web/src/components/settings/SystemSetting.js index 8219159b..b5829f31 100644 --- a/web/src/components/settings/SystemSetting.js +++ b/web/src/components/settings/SystemSetting.js @@ -17,7 +17,6 @@ import { removeTrailingSlash, showError, showSuccess, - verifyJSON } from '../../helpers'; import axios from 'axios'; @@ -42,17 +41,9 @@ const SystemSetting = () => { SMTPAccount: '', SMTPFrom: '', SMTPToken: '', - ServerAddress: '', WorkerUrl: '', WorkerValidKey: '', WorkerAllowHttpImageRequestEnabled: '', - EpayId: '', - EpayKey: '', - Price: 7.3, - MinTopUp: 1, - TopupGroupRatio: '', - PayAddress: '', - CustomCallbackAddress: '', Footer: '', WeChatAuthEnabled: '', WeChatServerAddress: '', @@ -199,11 +190,6 @@ const SystemSetting = () => { setInputs(values); }; - const submitServerAddress = async () => { - let ServerAddress = removeTrailingSlash(inputs.ServerAddress); - await updateOptions([{ key: 'ServerAddress', value: ServerAddress }]); - }; - const submitWorker = async () => { let WorkerUrl = removeTrailingSlash(inputs.WorkerUrl); const options = [ @@ -219,47 +205,6 @@ const SystemSetting = () => { await updateOptions(options); }; - const submitPayAddress = async () => { - if (inputs.ServerAddress === '') { - showError('请先填写服务器地址'); - return; - } - if (originInputs['TopupGroupRatio'] !== inputs.TopupGroupRatio) { - if (!verifyJSON(inputs.TopupGroupRatio)) { - showError('充值分组倍率不是合法的 JSON 字符串'); - return; - } - } - - const options = [ - { key: 'PayAddress', value: removeTrailingSlash(inputs.PayAddress) }, - ]; - - if (inputs.EpayId !== '') { - options.push({ key: 'EpayId', value: inputs.EpayId }); - } - if (inputs.EpayKey !== undefined && inputs.EpayKey !== '') { - options.push({ key: 'EpayKey', value: inputs.EpayKey }); - } - if (inputs.Price !== '') { - options.push({ key: 'Price', value: inputs.Price.toString() }); - } - if (inputs.MinTopUp !== '') { - options.push({ key: 'MinTopUp', value: inputs.MinTopUp.toString() }); - } - if (inputs.CustomCallbackAddress !== '') { - options.push({ - key: 'CustomCallbackAddress', - value: inputs.CustomCallbackAddress, - }); - } - if (originInputs['TopupGroupRatio'] !== inputs.TopupGroupRatio) { - options.push({ key: 'TopupGroupRatio', value: inputs.TopupGroupRatio }); - } - - await updateOptions(options); - }; - const submitSMTP = async () => { const options = []; @@ -541,17 +486,6 @@ const SystemSetting = () => { marginTop: '10px', }} > - - - - - - @@ -594,74 +528,6 @@ const SystemSetting = () => { - - - - (当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!) - - -
- - - - - - - - - - - - - - - - - - - - - - - - - { const { t } = useTranslation(); let type2label = undefined; - const renderType = (type, multiKeyMode=false) => { + const renderType = (type) => { if (!type2label) { type2label = new Map(); for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { @@ -82,24 +61,6 @@ const ChannelsTable = () => { } type2label[0] = { value: 0, label: t('未知类型'), color: 'grey' }; } - - if (multiKeyMode) { - return ( - - - {getChannelIcon(type)} - - } - > - {type2label[type]?.label} - - ); - } return ( { return ( } size='large' shape='circle' type='light' @@ -126,85 +86,29 @@ const ChannelsTable = () => { ); }; - const renderMultiKeyStatus = (status, channelInfo) => { - if (!channelInfo || !channelInfo.multi_key_mode) { - return renderStatus(status, channelInfo); - } - - const { multi_key_status_list, multi_key_size } = channelInfo; - const totalCount = multi_key_size || 0; - - // If multi_key_status_list is null, it means all keys are enabled - if (!multi_key_status_list) { - return ( - }> - {t('已启用')}:{totalCount}/{totalCount} - - ); - } - - // Count enabled keys from the status map - const statusValues = Object.values(multi_key_status_list); - const enabledCount = statusValues.filter(s => s === 1).length; - - // Determine status text, color and icon based on enabled ratio - let statusText, statusColor, statusIcon; - const enabledRatio = totalCount > 0 ? enabledCount / totalCount : 0; - - if (enabledCount === totalCount) { - statusText = t('已启用'); - statusColor = 'green'; - statusIcon = ; - } else if (enabledCount === 0) { - statusText = t('已禁用'); - statusColor = 'red'; - statusIcon = ; - } else { - statusText = t('部分启用'); - // Color based on percentage: green (>80%), yellow (20-80%), red (<20%) - if (enabledRatio > 0.8) { - statusColor = 'green'; - } else if (enabledRatio >= 0.2) { - statusColor = 'yellow'; - } else { - statusColor = 'red'; - } - statusIcon = ; - } - - return ( - - {statusText}:{enabledCount}/{totalCount} - - ); - }; - - const renderStatus = (status, channelInfo=undefined) => { - if (channelInfo?.multi_key_mode) { - return renderMultiKeyStatus(status, channelInfo); - } + const renderStatus = (status) => { switch (status) { case 1: return ( - }> + {t('已启用')} ); case 2: return ( - }> + {t('已禁用')} ); case 3: return ( - }> + {t('自动禁用')} ); default: return ( - }> + {t('未知状态')} ); @@ -216,31 +120,31 @@ const ChannelsTable = () => { time = time.toFixed(2) + t(' 秒'); if (responseTime === 0) { return ( - }> + {t('未测试')} ); } else if (responseTime <= 1000) { return ( - }> + {time} ); } else if (responseTime <= 3000) { return ( - }> + {time} ); } else if (responseTime <= 5000) { return ( - }> + {time} ); } else { return ( - }> + {time} ); @@ -265,6 +169,11 @@ const ChannelsTable = () => { const [visibleColumns, setVisibleColumns] = useState({}); const [showColumnSelector, setShowColumnSelector] = useState(false); + // 状态筛选 all / enabled / disabled + const [statusFilter, setStatusFilter] = useState( + localStorage.getItem('channel-status-filter') || 'all' + ); + // Load saved column preferences from localStorage useEffect(() => { const savedColumns = localStorage.getItem('channels-table-columns'); @@ -372,7 +281,7 @@ const ChannelsTable = () => { dataIndex: 'type', render: (text, record, index) => { if (record.children === undefined) { - return <>{renderType(text, record.channel_info?.multi_key_mode)}; + return <>{renderType(text)}; } else { return <>{renderTagType()}; } @@ -395,12 +304,12 @@ const ChannelsTable = () => { - {renderStatus(text, record.channel_info)} + {renderStatus(text)} ); } else { - return renderStatus(text, record.channel_info); + return renderStatus(text); } }, }, @@ -422,7 +331,7 @@ const ChannelsTable = () => {
- }> + {renderQuota(record.used_quota)} @@ -432,7 +341,6 @@ const ChannelsTable = () => { type='ghost' size='large' shape='circle' - prefixIcon={} onClick={() => updateChannelBalance(record)} > {renderQuotaWithAmount(record.balance)} @@ -444,7 +352,7 @@ const ChannelsTable = () => { } else { return ( - }> + {renderQuota(record.used_quota)} @@ -568,7 +476,6 @@ const ChannelsTable = () => { { node: 'item', name: t('删除'), - icon: , type: 'danger', onClick: () => { Modal.confirm({ @@ -585,7 +492,6 @@ const ChannelsTable = () => { { node: 'item', name: t('复制'), - icon: , type: 'primary', onClick: () => { Modal.confirm({ @@ -600,7 +506,7 @@ const ChannelsTable = () => { return ( - { + setShowEditTag(true); + setEditingTag(record.key); + }} > - ); } @@ -754,19 +635,22 @@ const ChannelsTable = () => { const [modelSearchKeyword, setModelSearchKeyword] = useState(''); const [modelTestResults, setModelTestResults] = useState({}); const [testingModels, setTestingModels] = useState(new Set()); + const [selectedModelKeys, setSelectedModelKeys] = useState([]); const [isBatchTesting, setIsBatchTesting] = useState(false); const [testQueue, setTestQueue] = useState([]); const [isProcessingQueue, setIsProcessingQueue] = useState(false); - - // Form API 引用 + const [modelTablePage, setModelTablePage] = useState(1); + const [activeTypeKey, setActiveTypeKey] = useState('all'); + const [typeCounts, setTypeCounts] = useState({}); + const requestCounter = useRef(0); const [formApi, setFormApi] = useState(null); - - // Form 初始值 + const [compactMode, setCompactMode] = useTableCompactMode('channels'); const formInitValues = { searchKeyword: '', searchGroup: '', searchModel: '', }; + const allSelectingRef = useRef(false); // Filter columns based on visibility settings const getVisibleColumns = () => { @@ -785,21 +669,18 @@ const ChannelsTable = () => { @@ -943,17 +824,41 @@ const ChannelsTable = () => { setChannels(channelDates); }; - const loadChannels = async (page, pageSize, idSort, enableTagMode) => { + const loadChannels = async ( + page, + pageSize, + idSort, + enableTagMode, + typeKey = activeTypeKey, + statusF, + ) => { + if (statusF === undefined) statusF = statusFilter; + + const { searchKeyword, searchGroup, searchModel } = getFormValues(); + if (searchKeyword !== '' || searchGroup !== '' || searchModel !== '') { + setLoading(true); + await searchChannels(enableTagMode, typeKey, statusF, page, pageSize, idSort); + setLoading(false); + return; + } + + const reqId = ++requestCounter.current; // 记录当前请求序号 setLoading(true); + const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : ''; + const statusParam = statusF !== 'all' ? `&status=${statusF}` : ''; const res = await API.get( - `/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}`, + `/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}${statusParam}`, ); - if (res === undefined) { + if (res === undefined || reqId !== requestCounter.current) { return; } const { success, message, data } = res.data; if (success) { - const { items, total } = data; + const { items, total, type_counts } = data; + if (type_counts) { + const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0); + setTypeCounts({ ...type_counts, all: sumAll }); + } setChannelFormat(items, enableTagMode); setChannelCount(total); } else { @@ -993,7 +898,7 @@ const ChannelsTable = () => { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { await loadChannels(activePage, pageSize, idSort, enableTagMode); } else { - await searchChannels(enableTagMode); + await searchChannels(enableTagMode, activeTypeKey, statusFilter, activePage, pageSize, idSort); } }; @@ -1099,7 +1004,7 @@ const ChannelsTable = () => { } }; - // 获取表单值的辅助函数,确保所有值都是字符串 + // 获取表单值的辅助函数 const getFormValues = () => { const formValues = formApi ? formApi.getValues() : {}; return { @@ -1109,23 +1014,35 @@ const ChannelsTable = () => { }; }; - const searchChannels = async (enableTagMode) => { + const searchChannels = async ( + enableTagMode, + typeKey = activeTypeKey, + statusF = statusFilter, + page = 1, + pageSz = pageSize, + sortFlag = idSort, + ) => { const { searchKeyword, searchGroup, searchModel } = getFormValues(); - setSearching(true); try { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { - await loadChannels(activePage - 1, pageSize, idSort, enableTagMode); + await loadChannels(page, pageSz, sortFlag, enableTagMode, typeKey, statusF); return; } + const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : ''; + const statusParam = statusF !== 'all' ? `&status=${statusF}` : ''; const res = await API.get( - `/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}&tag_mode=${enableTagMode}`, + `/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${sortFlag}&tag_mode=${enableTagMode}&p=${page}&page_size=${pageSz}${typeParam}${statusParam}`, ); const { success, message, data } = res.data; if (success) { - setChannelFormat(data, enableTagMode); - setActivePage(1); + const { items = [], total = 0, type_counts = {} } = data; + const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0); + setTypeCounts({ ...type_counts, all: sumAll }); + setChannelFormat(items, enableTagMode); + setChannelCount(total); + setActivePage(page); } else { showError(message); } @@ -1165,7 +1082,22 @@ const ChannelsTable = () => { const processTestQueue = async () => { if (!isProcessingQueue || testQueue.length === 0) return; - const { channel, model } = testQueue[0]; + const { channel, model, indexInFiltered } = testQueue[0]; + + // 自动翻页到正在测试的模型所在页 + if (currentTestChannel && currentTestChannel.id === channel.id) { + let pageNo; + if (indexInFiltered !== undefined) { + pageNo = Math.floor(indexInFiltered / MODEL_TABLE_PAGE_SIZE) + 1; + } else { + const filteredModelsList = currentTestChannel.models + .split(',') + .filter((m) => m.toLowerCase().includes(modelSearchKeyword.toLowerCase())); + const modelIdx = filteredModelsList.indexOf(model); + pageNo = modelIdx !== -1 ? Math.floor(modelIdx / MODEL_TABLE_PAGE_SIZE) + 1 : 1; + } + setModelTablePage(pageNo); + } try { setTestingModels(prev => new Set([...prev, model])); @@ -1228,16 +1160,22 @@ const ChannelsTable = () => { setIsBatchTesting(true); - const models = currentTestChannel.models + // 重置分页到第一页 + setModelTablePage(1); + + const filteredModels = currentTestChannel.models .split(',') .filter((model) => - model.toLowerCase().includes(modelSearchKeyword.toLowerCase()) + model.toLowerCase().includes(modelSearchKeyword.toLowerCase()), ); - setTestQueue(models.map(model => ({ - channel: currentTestChannel, - model - }))); + setTestQueue( + filteredModels.map((model, idx) => ({ + channel: currentTestChannel, + model, + indexInFiltered: idx, // 记录在过滤列表中的顺序 + })), + ); setIsProcessingQueue(true); }; @@ -1251,25 +1189,113 @@ const ChannelsTable = () => { } else { setShowModelTestModal(false); setModelSearchKeyword(''); + setSelectedModelKeys([]); + setModelTablePage(1); } }; + const channelTypeCounts = useMemo(() => { + if (Object.keys(typeCounts).length > 0) return typeCounts; + // fallback 本地计算 + const counts = { all: channels.length }; + channels.forEach((channel) => { + const collect = (ch) => { + const type = ch.type; + counts[type] = (counts[type] || 0) + 1; + }; + if (channel.children !== undefined) { + channel.children.forEach(collect); + } else { + collect(channel); + } + }); + return counts; + }, [typeCounts, channels]); + + const availableTypeKeys = useMemo(() => { + const keys = ['all']; + Object.entries(channelTypeCounts).forEach(([k, v]) => { + if (k !== 'all' && v > 0) keys.push(String(k)); + }); + return keys; + }, [channelTypeCounts]); + + const renderTypeTabs = () => { + if (enableTagMode) return null; + + return ( + { + setActiveTypeKey(key); + setActivePage(1); + loadChannels(1, pageSize, idSort, enableTagMode, key); + }} + className="mb-4" + > + + {t('全部')} + + {channelTypeCounts['all'] || 0} + + + } + /> + + {CHANNEL_OPTIONS.filter((opt) => availableTypeKeys.includes(String(opt.value))).map((option) => { + const key = String(option.value); + const count = channelTypeCounts[option.value] || 0; + return ( + + {getChannelIcon(option.value)} + {option.label} + + {count} + + + } + /> + ); + })} + + ); + }; + let pageData = channels; const handlePageChange = (page) => { + const { searchKeyword, searchGroup, searchModel } = getFormValues(); setActivePage(page); - loadChannels(page, pageSize, idSort, enableTagMode).then(() => { }); + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + loadChannels(page, pageSize, idSort, enableTagMode).then(() => { }); + } else { + searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort); + } }; const handlePageSizeChange = async (size) => { localStorage.setItem('page-size', size + ''); setPageSize(size); setActivePage(1); - loadChannels(1, size, idSort, enableTagMode) - .then() - .catch((reason) => { - showError(reason); - }); + const { searchKeyword, searchGroup, searchModel } = getFormValues(); + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + loadChannels(1, size, idSort, enableTagMode) + .then() + .catch((reason) => { + showError(reason); + }); + } else { + searchChannels(enableTagMode, activeTypeKey, statusFilter, 1, size, idSort); + } }; const fetchGroups = async () => { @@ -1446,13 +1472,15 @@ const ChannelsTable = () => { const renderHeader = () => (
+ {renderTypeTabs()}
+ +
@@ -1565,11 +1609,17 @@ const ChannelsTable = () => { {t('使用ID排序')} { localStorage.setItem('id-sort', v + ''); setIdSort(v); - loadChannels(activePage, pageSize, v, enableTagMode); + const { searchKeyword, searchGroup, searchModel } = getFormValues(); + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { + loadChannels(activePage, pageSize, v, enableTagMode); + } else { + searchChannels(enableTagMode, activeTypeKey, statusFilter, activePage, pageSize, v); + } }} />
@@ -1579,6 +1629,7 @@ const ChannelsTable = () => { {t('开启批量操作')} { localStorage.setItem('enable-batch-delete', v + ''); @@ -1592,6 +1643,7 @@ const ChannelsTable = () => { {t('标签聚合模式')} { localStorage.setItem('enable-tag-mode', v + ''); @@ -1601,6 +1653,27 @@ const ChannelsTable = () => { }} />
+ + {/* 状态筛选器 */} +
+ + {t('状态筛选')} + + +
@@ -1609,10 +1682,10 @@ const ChannelsTable = () => {
@@ -1658,33 +1731,34 @@ const ChannelsTable = () => { >
} - placeholder={t('搜索渠道的 ID,名称,密钥和API地址 ...')} - className="!rounded-full" + placeholder={t('渠道ID,名称,密钥,API地址')} showClear pure />
} placeholder={t('模型关键字')} - className="!rounded-full" showClear pure />
-
+
{ @@ -1696,14 +1770,16 @@ const ChannelsTable = () => { />
@@ -1747,9 +1823,9 @@ const ChannelsTable = () => { bordered={false} >
rest) : getVisibleColumns()} dataSource={pageData} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} pagination={{ currentPage: activePage, pageSize: pageSize, @@ -1787,7 +1863,7 @@ const ChannelsTable = () => { } className="rounded-xl overflow-hidden" size="middle" - loading={loading} + loading={loading || searching} /> @@ -1810,7 +1886,6 @@ const ChannelsTable = () => { value={batchSetTagValue} onChange={(v) => setBatchSetTagValue(v)} size='large' - className="!rounded-full" />
@@ -1823,13 +1898,70 @@ const ChannelsTable = () => { - - {currentTestChannel.name} {t('渠道的模型测试')} - - - {t('共')} {currentTestChannel.models.split(',').length} {t('个模型')} - +
+
+ + {currentTestChannel.name} {t('渠道的模型测试')} + + + {t('共')} {currentTestChannel.models.split(',').length} {t('个模型')} + +
+ + {/* 搜索与操作按钮 */} +
+ { + setModelSearchKeyword(v); + setModelTablePage(1); + }} + className="!w-full" + prefix={} + showClear + /> + + + + +
) } @@ -1841,7 +1973,6 @@ const ChannelsTable = () => {
{ if (isTesting) { return ( - + {t('测试中')} ); @@ -1923,7 +2041,7 @@ const ChannelsTable = () => { if (!testResult) { return ( - + {t('未开始')} ); @@ -1934,7 +2052,7 @@ const ChannelsTable = () => { {testResult.success ? t('成功') : t('失败')} @@ -1956,11 +2074,9 @@ const ChannelsTable = () => { @@ -1968,16 +2084,47 @@ const ChannelsTable = () => { } } ]} - dataSource={currentTestChannel.models - .split(',') - .filter((model) => - model.toLowerCase().includes(modelSearchKeyword.toLowerCase()) - ) - .map((model) => ({ + dataSource={(() => { + const filtered = currentTestChannel.models + .split(',') + .filter((model) => + model.toLowerCase().includes(modelSearchKeyword.toLowerCase()), + ); + const start = (modelTablePage - 1) * MODEL_TABLE_PAGE_SIZE; + const end = start + MODEL_TABLE_PAGE_SIZE; + return filtered.slice(start, end).map((model) => ({ model, - key: model - }))} - pagination={false} + key: model, + })); + })()} + rowSelection={{ + selectedRowKeys: selectedModelKeys, + onChange: (keys) => { + if (allSelectingRef.current) { + allSelectingRef.current = false; + return; + } + setSelectedModelKeys(keys); + }, + onSelectAll: (checked) => { + const filtered = currentTestChannel.models + .split(',') + .filter((m) => m.toLowerCase().includes(modelSearchKeyword.toLowerCase())); + allSelectingRef.current = true; + setSelectedModelKeys(checked ? filtered : []); + }, + }} + pagination={{ + currentPage: modelTablePage, + pageSize: MODEL_TABLE_PAGE_SIZE, + total: currentTestChannel.models + .split(',') + .filter((model) => + model.toLowerCase().includes(modelSearchKeyword.toLowerCase()), + ).length, + showSizeChanger: false, + onPageChange: (page) => setModelTablePage(page), + }} /> )} diff --git a/web/src/components/table/LogsTable.js b/web/src/components/table/LogsTable.js index 90e4a809..cc1d2082 100644 --- a/web/src/components/table/LogsTable.js +++ b/web/src/components/table/LogsTable.js @@ -47,8 +47,9 @@ import { } from '@douyinfe/semi-illustrations'; import { ITEMS_PER_PAGE } from '../../constants'; import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph'; -import { IconSetting, IconSearch, IconHelpCircle } from '@douyinfe/semi-icons'; +import { IconSearch, IconHelpCircle } from '@douyinfe/semi-icons'; import { Route } from 'lucide-react'; +import { useTableCompactMode } from '../../hooks/useTableCompactMode'; const { Text } = Typography; @@ -192,7 +193,7 @@ const LogsTable = () => { if (!modelMapped) { return renderModelTag(record.model_name, { onClick: (event) => { - copyText(event, record.model_name).then((r) => {}); + copyText(event, record.model_name).then((r) => { }); }, }); } else { @@ -209,7 +210,7 @@ const LogsTable = () => { {renderModelTag(record.model_name, { onClick: (event) => { - copyText(event, record.model_name).then((r) => {}); + copyText(event, record.model_name).then((r) => { }); }, })} @@ -220,7 +221,7 @@ const LogsTable = () => { {renderModelTag(other.upstream_model_name, { onClick: (event) => { copyText(event, other.upstream_model_name).then( - (r) => {}, + (r) => { }, ); }, })} @@ -231,7 +232,7 @@ const LogsTable = () => { > {renderModelTag(record.model_name, { onClick: (event) => { - copyText(event, record.model_name).then((r) => {}); + copyText(event, record.model_name).then((r) => { }); }, suffixIcon: ( { } let content = other?.claude ? renderClaudeModelPriceSimple( - other.model_ratio, - other.model_price, - other.group_ratio, - other?.user_group_ratio, - other.cache_tokens || 0, - other.cache_ratio || 1.0, - other.cache_creation_tokens || 0, - other.cache_creation_ratio || 1.0, - ) + other.model_ratio, + other.model_price, + other.group_ratio, + other?.user_group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + other.cache_creation_tokens || 0, + other.cache_creation_ratio || 1.0, + ) : renderModelPriceSimple( - other.model_ratio, - other.model_price, - other.group_ratio, - other?.user_group_ratio, - other.cache_tokens || 0, - other.cache_ratio || 1.0, - ); + other.model_ratio, + other.model_price, + other.group_ratio, + other?.user_group_ratio, + other.cache_tokens || 0, + other.cache_ratio || 1.0, + ); return ( { @@ -985,27 +983,27 @@ const LogsTable = () => { key: t('日志详情'), value: other?.claude ? renderClaudeLogContent( - other?.model_ratio, - other.completion_ratio, - other.model_price, - other.group_ratio, - other?.user_group_ratio, - other.cache_ratio || 1.0, - other.cache_creation_ratio || 1.0, - ) + other?.model_ratio, + other.completion_ratio, + other.model_price, + other.group_ratio, + other?.user_group_ratio, + other.cache_ratio || 1.0, + other.cache_creation_ratio || 1.0, + ) : renderLogContent( - other?.model_ratio, - other.completion_ratio, - other.model_price, - other.group_ratio, - other?.user_group_ratio, - false, - 1.0, - other.web_search || false, - other.web_search_call_count || 0, - other.file_search || false, - other.file_search_call_count || 0, - ), + other?.model_ratio, + other.completion_ratio, + other.model_price, + other.group_ratio, + other?.user_group_ratio, + false, + 1.0, + other.web_search || false, + other.web_search_call_count || 0, + other.file_search || false, + other.file_search_call_count || 0, + ), }); } if (logs[i].type === 2) { @@ -1145,7 +1143,7 @@ const LogsTable = () => { const handlePageChange = (page) => { setActivePage(page); - loadLogs(page, pageSize).then((r) => {}); // 不传入logType,让其从表单获取最新值 + loadLogs(page, pageSize).then((r) => { }); // 不传入logType,让其从表单获取最新值 }; const handlePageSizeChange = async (size) => { @@ -1203,6 +1201,8 @@ const LogsTable = () => { ); }; + const [compactMode, setCompactMode] = useTableCompactMode('logs'); + return ( <> {renderColumnSelector()} @@ -1211,45 +1211,56 @@ const LogsTable = () => { title={
- - + + + {t('消耗额度')}: {renderQuota(stat.quota)} + + + RPM: {stat.rpm} + + + TPM: {stat.tpm} + + + + +
@@ -1284,7 +1295,6 @@ const LogsTable = () => { field='token_name' prefix={} placeholder={t('令牌名称')} - className='!rounded-full' showClear pure /> @@ -1293,7 +1303,6 @@ const LogsTable = () => { field='model_name' prefix={} placeholder={t('模型名称')} - className='!rounded-full' showClear pure /> @@ -1302,7 +1311,6 @@ const LogsTable = () => { field='group' prefix={} placeholder={t('分组')} - className='!rounded-full' showClear pure /> @@ -1313,7 +1321,6 @@ const LogsTable = () => { field='channel' prefix={} placeholder={t('渠道 ID')} - className='!rounded-full' showClear pure /> @@ -1321,7 +1328,6 @@ const LogsTable = () => { field='username' prefix={} placeholder={t('用户名称')} - className='!rounded-full' showClear pure /> @@ -1336,7 +1342,7 @@ const LogsTable = () => { { @@ -1372,7 +1378,6 @@ const LogsTable = () => { type='primary' htmlType='submit' loading={loading} - className='!rounded-full' > {t('查询')} @@ -1382,22 +1387,18 @@ const LogsTable = () => { if (formApi) { formApi.reset(); setLogType(0); - // 重置后立即查询,使用setTimeout确保表单重置完成 setTimeout(() => { refresh(); }, 100); } }} - className='!rounded-full' > {t('重置')} @@ -1411,7 +1412,7 @@ const LogsTable = () => { bordered={false} >
rest) : getVisibleColumns()} {...(hasExpandableRows() && { expandedRowRender: expandRowRender, expandRowByClick: true, @@ -1421,7 +1422,7 @@ const LogsTable = () => { dataSource={logs} rowKey='key' loading={loading} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} className='rounded-xl overflow-hidden' size='middle' empty={ diff --git a/web/src/components/table/MjLogsTable.js b/web/src/components/table/MjLogsTable.js index 869db485..66e52dd6 100644 --- a/web/src/components/table/MjLogsTable.js +++ b/web/src/components/table/MjLogsTable.js @@ -24,7 +24,7 @@ import { XCircle, Loader, AlertCircle, - Hash + Hash, } from 'lucide-react'; import { API, @@ -59,8 +59,8 @@ import { ITEMS_PER_PAGE } from '../../constants'; import { IconEyeOpened, IconSearch, - IconSetting } from '@douyinfe/semi-icons'; +import { useTableCompactMode } from '../../hooks/useTableCompactMode'; const { Text } = Typography; @@ -107,6 +107,7 @@ const LogsTable = () => { const [visibleColumns, setVisibleColumns] = useState({}); const [showColumnSelector, setShowColumnSelector] = useState(false); const isAdminUser = isAdmin(); + const [compactMode, setCompactMode] = useTableCompactMode('mjLogs'); // 加载保存的列偏好设置 useEffect(() => { @@ -194,6 +195,18 @@ const LogsTable = () => { {t('放大')} ); + case 'VIDEO': + return ( + }> + {t('视频')} + + ); + case 'EDITS': + return ( + }> + {t('编辑')} + + ); case 'VARIATION': return ( }> @@ -514,7 +527,6 @@ const LogsTable = () => { setModalImageUrl(text); setIsModalOpenurl(true); }} - className="!rounded-full" > {t('查看图片')} @@ -732,21 +744,18 @@ const LogsTable = () => { @@ -802,7 +811,7 @@ const LogsTable = () => { className="!rounded-2xl mb-4" title={
-
+
{loading ? ( @@ -821,6 +830,14 @@ const LogsTable = () => { )}
+
@@ -855,7 +872,6 @@ const LogsTable = () => { field='mj_id' prefix={} placeholder={t('任务 ID')} - className="!rounded-full" showClear pure /> @@ -866,7 +882,6 @@ const LogsTable = () => { field='channel_id' prefix={} placeholder={t('渠道 ID')} - className="!rounded-full" showClear pure /> @@ -881,7 +896,6 @@ const LogsTable = () => { type='primary' htmlType='submit' loading={loading} - className="!rounded-full" > {t('查询')} @@ -896,16 +910,13 @@ const LogsTable = () => { }, 100); } }} - className="!rounded-full" > {t('重置')} @@ -919,11 +930,11 @@ const LogsTable = () => { bordered={false} >
rest) : getVisibleColumns()} dataSource={logs} rowKey='key' loading={loading} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} className="rounded-xl overflow-hidden" size="middle" empty={ diff --git a/web/src/components/table/ModelPricing.js b/web/src/components/table/ModelPricing.js index b81274c7..bf9df911 100644 --- a/web/src/components/table/ModelPricing.js +++ b/web/src/components/table/ModelPricing.js @@ -1,5 +1,5 @@ import React, { useContext, useEffect, useRef, useMemo, useState } from 'react'; -import { API, copy, showError, showInfo, showSuccess, getModelCategories, renderModelTag } from '../../helpers'; +import { API, copy, showError, showInfo, showSuccess, getModelCategories, renderModelTag, stringToColor } from '../../helpers'; import { useTranslation } from 'react-i18next'; import { @@ -16,7 +16,6 @@ import { Card, Tabs, TabPane, - Dropdown, Empty } from '@douyinfe/semi-ui'; import { @@ -107,6 +106,26 @@ const ModelPricing = () => { ) : null; } + function renderSupportedEndpoints(endpoints) { + if (!endpoints || endpoints.length === 0) { + return null; + } + return ( + + {endpoints.map((endpoint, idx) => ( + + {endpoint} + + ))} + + ); + } + const columns = [ { title: t('可用性'), @@ -121,6 +140,13 @@ const ModelPricing = () => { }, defaultSortOrder: 'descend', }, + { + title: t('可用端点类型'), + dataIndex: 'supported_endpoint_types', + render: (text, record, index) => { + return renderSupportedEndpoints(text); + }, + }, { title: t('模型名称'), dataIndex: 'model_name', @@ -162,6 +188,7 @@ const ModelPricing = () => { { setSelectedGroup(group); showInfo( @@ -171,7 +198,7 @@ const ModelPricing = () => { }), ); }} - className="cursor-pointer hover:opacity-80 transition-opacity !rounded-full" + className="cursor-pointer hover:opacity-80 transition-opacity" > {group} @@ -257,7 +284,7 @@ const ModelPricing = () => { const [models, setModels] = useState([]); const [loading, setLoading] = useState(true); - const [userState, userDispatch] = useContext(UserContext); + const [userState] = useContext(UserContext); const [groupRatio, setGroupRatio] = useState({}); const [usableGroup, setUsableGroup] = useState({}); @@ -334,57 +361,6 @@ const ModelPricing = () => { return counts; }, [models, modelCategories]); - const renderArrow = (items, pos, handleArrowClick) => { - const style = { - width: 32, - height: 32, - margin: '0 12px', - display: 'flex', - justifyContent: 'center', - alignItems: 'center', - borderRadius: '100%', - background: 'rgba(var(--semi-grey-1), 1)', - color: 'var(--semi-color-text)', - cursor: 'pointer', - }; - return ( - - {items.map(item => { - const key = item.itemKey; - const modelCount = categoryCounts[key] || 0; - - return ( - setActiveKey(item.itemKey)} - icon={modelCategories[item.itemKey]?.icon} - > -
- {modelCategories[item.itemKey]?.label || item.itemKey} - - {modelCount} - -
-
- ); - })} - - } - > -
- {pos === 'start' ? '←' : '→'} -
-
- ); - }; - - // 检查分类是否有对应的模型 const availableCategories = useMemo(() => { if (!models.length) return ['all']; @@ -394,11 +370,9 @@ const ModelPricing = () => { }).map(([key]) => key); }, [models]); - // 渲染标签页 const renderTabs = () => { return ( { ); }; - // 优化过滤逻辑 const filteredModels = useMemo(() => { let result = models; - // 先按分类过滤 if (activeKey !== 'all') { result = result.filter(model => modelCategories[activeKey].filter(model)); } - // 再按搜索词过滤 if (filteredValue.length > 0) { const searchTerm = filteredValue[0].toLowerCase(); result = result.filter(model => @@ -454,7 +425,6 @@ const ModelPricing = () => { return result; }, [activeKey, models, filteredValue]); - // 搜索和操作区组件 const SearchAndActions = useMemo(() => (
@@ -462,7 +432,6 @@ const ModelPricing = () => { } placeholder={t('模糊搜索模型名称')} - className="!rounded-lg" onCompositionStart={handleCompositionStart} onCompositionEnd={handleCompositionEnd} onChange={handleChange} @@ -476,7 +445,7 @@ const ModelPricing = () => { icon={} onClick={() => copyText(selectedRowKeys)} disabled={selectedRowKeys.length === 0} - className="!rounded-lg !bg-blue-500 hover:!bg-blue-600 text-white" + className="!bg-blue-500 hover:!bg-blue-600 text-white" size="large" > {t('复制选中模型')} @@ -485,7 +454,6 @@ const ModelPricing = () => { ), [selectedRowKeys, t]); - // 表格组件 const ModelTable = useMemo(() => (
{
-
+
{/* 主卡片容器 */} - + {/* 顶部状态卡片 */} { }} bodyStyle={{ padding: 0 }} > - {/* 装饰性背景元素 */} -
-
-
-
-
-
@@ -565,7 +526,7 @@ const ModelPricing = () => {
- {t('未登录,使用默认分组倍率')}: {groupRatio['default']} + {t('未登录,使用默认分组倍率:')}{groupRatio['default']}
)} diff --git a/web/src/components/table/RedemptionsTable.js b/web/src/components/table/RedemptionsTable.js index e11a4657..9bdb603f 100644 --- a/web/src/components/table/RedemptionsTable.js +++ b/web/src/components/table/RedemptionsTable.js @@ -8,14 +8,7 @@ import { renderQuota } from '../../helpers'; -import { - CheckCircle, - XCircle, - Minus, - HelpCircle, - Coins, - Ticket -} from 'lucide-react'; +import { Ticket } from 'lucide-react'; import { ITEMS_PER_PAGE } from '../../constants'; import { @@ -37,18 +30,12 @@ import { IllustrationNoResultDark } from '@douyinfe/semi-illustrations'; import { - IconPlus, - IconCopy, IconSearch, - IconEyeOpened, - IconEdit, - IconDelete, - IconStop, - IconPlay, - IconMore + IconMore, } from '@douyinfe/semi-icons'; import EditRedemption from '../../pages/Redemption/EditRedemption'; import { useTranslation } from 'react-i18next'; +import { useTableCompactMode } from '../../hooks/useTableCompactMode'; const { Text } = Typography; @@ -66,31 +53,31 @@ const RedemptionsTable = () => { const renderStatus = (status, record) => { if (isExpired(record)) { return ( - }>{t('已过期')} + {t('已过期')} ); } switch (status) { case 1: return ( - }> + {t('未使用')} ); case 2: return ( - }> + {t('已禁用')} ); case 3: return ( - }> + {t('已使用')} ); default: return ( - }> + {t('未知状态')} ); @@ -120,7 +107,7 @@ const RedemptionsTable = () => { render: (text, record, index) => { return (
- }> + {renderQuota(parseInt(text))}
@@ -158,7 +145,6 @@ const RedemptionsTable = () => { { node: 'item', name: t('删除'), - icon: , type: 'danger', onClick: () => { Modal.confirm({ @@ -178,7 +164,6 @@ const RedemptionsTable = () => { moreMenuItems.push({ node: 'item', name: t('禁用'), - icon: , type: 'warning', onClick: () => { manageRedemption(record.id, 'disable', record); @@ -188,7 +173,6 @@ const RedemptionsTable = () => { moreMenuItems.push({ node: 'item', name: t('启用'), - icon: , type: 'secondary', onClick: () => { manageRedemption(record.id, 'enable', record); @@ -201,21 +185,17 @@ const RedemptionsTable = () => {
@@ -479,8 +442,7 @@ const RedemptionsTable = () => {
@@ -583,7 +542,7 @@ const RedemptionsTable = () => { }, 100); } }} - className="!rounded-full flex-1 md:flex-initial md:w-auto" + className="flex-1 md:flex-initial md:w-auto" > {t('重置')} @@ -610,9 +569,9 @@ const RedemptionsTable = () => { bordered={false} >
rest) : columns} dataSource={pageData} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} pagination={{ currentPage: activePage, pageSize: pageSize, diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js index 449b3d55..de41478e 100644 --- a/web/src/components/table/TaskLogsTable.js +++ b/web/src/components/table/TaskLogsTable.js @@ -11,7 +11,9 @@ import { XCircle, Loader, List, - Hash + Hash, + Video, + Sparkles } from 'lucide-react'; import { API, @@ -45,8 +47,9 @@ import { ITEMS_PER_PAGE } from '../../constants'; import { IconEyeOpened, IconSearch, - IconSetting } from '@douyinfe/semi-icons'; +import { useTableCompactMode } from '../../hooks/useTableCompactMode'; +import { TASK_ACTION_GENERATE, TASK_ACTION_TEXT_GENERATE } from '../../constants/common.constant'; const { Text } = Typography; @@ -80,6 +83,7 @@ const COLUMN_KEYS = { TASK_STATUS: 'task_status', PROGRESS: 'progress', FAIL_REASON: 'fail_reason', + RESULT_URL: 'result_url', }; const renderTimestamp = (timestampInSeconds) => { @@ -96,20 +100,8 @@ const renderTimestamp = (timestampInSeconds) => { }; function renderDuration(submit_time, finishTime) { - // 确保startTime和finishTime都是有效的时间戳 if (!submit_time || !finishTime) return 'N/A'; - - // 将时间戳转换为Date对象 - const start = new Date(submit_time); - const finish = new Date(finishTime); - - // 计算时间差(毫秒) - const durationMs = finish - start; - - // 将时间差转换为秒,并保留一位小数 - const durationSec = (durationMs / 1000).toFixed(1); - - // 设置颜色:大于60秒则为红色,小于等于60秒则为绿色 + const durationSec = finishTime - submit_time; const color = durationSec > 60 ? 'red' : 'green'; // 返回带有样式的颜色标签 @@ -162,6 +154,7 @@ const LogsTable = () => { [COLUMN_KEYS.TASK_STATUS]: true, [COLUMN_KEYS.PROGRESS]: true, [COLUMN_KEYS.FAIL_REASON]: true, + [COLUMN_KEYS.RESULT_URL]: true, }; }; @@ -215,6 +208,18 @@ const LogsTable = () => { {t('生成歌词')} ); + case TASK_ACTION_GENERATE: + return ( + }> + {t('图生视频')} + + ); + case TASK_ACTION_TEXT_GENERATE: + return ( + }> + {t('文生视频')} + + ); default: return ( }> @@ -224,14 +229,26 @@ const LogsTable = () => { } }; - const renderPlatform = (type) => { - switch (type) { + const renderPlatform = (platform) => { + switch (platform) { case 'suno': return ( }> Suno ); + case 'kling': + return ( + }> + Kling + + ); + case 'jimeng': + return ( + }> + Jimeng + + ); default: return ( }> @@ -423,10 +440,21 @@ const LogsTable = () => { }, { key: COLUMN_KEYS.FAIL_REASON, - title: t('失败原因'), + title: t('详情'), dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { + // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + const isVideoTask = record.action === TASK_ACTION_GENERATE || record.action === TASK_ACTION_TEXT_GENERATE; + const isSuccess = record.status === 'SUCCESS'; + const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); + if (isSuccess && isVideoTask && isUrl) { + return ( + + {t('点击预览视频')} + + ); + } if (!text) { return t('无'); } @@ -456,6 +484,8 @@ const LogsTable = () => { const [logs, setLogs] = useState([]); const [loading, setLoading] = useState(false); + const [compactMode, setCompactMode] = useTableCompactMode('taskLogs'); + useEffect(() => { const localPageSize = parseInt(localStorage.getItem('task-page-size')) || ITEMS_PER_PAGE; setPageSize(localPageSize); @@ -569,21 +599,18 @@ const LogsTable = () => { @@ -635,7 +662,7 @@ const LogsTable = () => { className="!rounded-2xl mb-4" title={
-
+
{loading ? ( @@ -650,6 +677,14 @@ const LogsTable = () => { {t('任务记录')} )}
+
@@ -684,7 +719,6 @@ const LogsTable = () => { field='task_id' prefix={} placeholder={t('任务 ID')} - className="!rounded-full" showClear pure /> @@ -695,7 +729,6 @@ const LogsTable = () => { field='channel_id' prefix={} placeholder={t('渠道 ID')} - className="!rounded-full" showClear pure /> @@ -710,7 +743,6 @@ const LogsTable = () => { type='primary' htmlType='submit' loading={loading} - className="!rounded-full" > {t('查询')} @@ -725,16 +757,13 @@ const LogsTable = () => { }, 100); } }} - className="!rounded-full" > {t('重置')} @@ -748,11 +777,11 @@ const LogsTable = () => { bordered={false} >
rest) : getVisibleColumns()} dataSource={logs} rowKey='key' loading={loading} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} className="rounded-xl overflow-hidden" size="middle" empty={ diff --git a/web/src/components/table/TokensTable.js b/web/src/components/table/TokensTable.js index bc6c7607..f91f7b82 100644 --- a/web/src/components/table/TokensTable.js +++ b/web/src/components/table/TokensTable.js @@ -9,7 +9,6 @@ import { renderQuota, getQuotaPerUnit } from '../../helpers'; - import { ITEMS_PER_PAGE } from '../../constants'; import { Button, @@ -29,33 +28,15 @@ import { IllustrationNoResult, IllustrationNoResultDark } from '@douyinfe/semi-illustrations'; - import { - CheckCircle, - Shield, - XCircle, - Clock, - Gauge, - HelpCircle, - Infinity, - Coins, - Key -} from 'lucide-react'; - -import { - IconPlus, - IconCopy, IconSearch, IconTreeTriangleDown, - IconEyeOpened, - IconEdit, - IconDelete, - IconStop, - IconPlay, - IconMore + IconMore, } from '@douyinfe/semi-icons'; +import { Key } from 'lucide-react'; import EditToken from '../../pages/Token/EditToken'; import { useTranslation } from 'react-i18next'; +import { useTableCompactMode } from '../../hooks/useTableCompactMode'; const { Text } = Typography; @@ -71,38 +52,38 @@ const TokensTable = () => { case 1: if (model_limits_enabled) { return ( - }> + {t('已启用:限制模型')} ); } else { return ( - }> + {t('已启用')} ); } case 2: return ( - }> + {t('已禁用')} ); case 3: return ( - }> + {t('已过期')} ); case 4: return ( - }> + {t('已耗尽')} ); default: return ( - }> + {t('未知状态')} ); @@ -135,7 +116,7 @@ const TokensTable = () => { render: (text, record, index) => { return (
- }> + {renderQuota(parseInt(text))}
@@ -162,7 +143,7 @@ const TokensTable = () => { return (
{record.unlimited_quota ? ( - }> + {t('无限制')} ) : ( @@ -170,7 +151,6 @@ const TokensTable = () => { size={'large'} color={getQuotaColor(parseInt(text))} shape='circle' - prefixIcon={} > {renderQuota(parseInt(text))} @@ -236,7 +216,6 @@ const TokensTable = () => { { node: 'item', name: t('查看'), - icon: , onClick: () => { Modal.info({ title: t('令牌详情'), @@ -248,7 +227,6 @@ const TokensTable = () => { { node: 'item', name: t('删除'), - icon: , type: 'danger', onClick: () => { Modal.confirm({ @@ -269,7 +247,6 @@ const TokensTable = () => { moreMenuItems.push({ node: 'item', name: t('禁用'), - icon: , type: 'warning', onClick: () => { manageToken(record.id, 'disable', record); @@ -279,7 +256,6 @@ const TokensTable = () => { moreMenuItems.push({ node: 'item', name: t('启用'), - icon: , type: 'secondary', onClick: () => { manageToken(record.id, 'enable', record); @@ -290,7 +266,7 @@ const TokensTable = () => { return (
-
+
+ + + ), + }); }} > - {t('复制所选令牌到剪贴板')} + {t('复制所选令牌')} + +
@@ -649,7 +706,6 @@ const TokensTable = () => { field="searchKeyword" prefix={} placeholder={t('搜索关键字')} - className="!rounded-full" showClear pure /> @@ -659,7 +715,6 @@ const TokensTable = () => { field="searchToken" prefix={} placeholder={t('密钥')} - className="!rounded-full" showClear pure /> @@ -669,7 +724,7 @@ const TokensTable = () => { type="primary" htmlType="submit" loading={loading || searching} - className="!rounded-full flex-1 md:flex-initial md:w-auto" + className="flex-1 md:flex-initial md:w-auto" > {t('查询')} @@ -684,7 +739,7 @@ const TokensTable = () => { }, 100); } }} - className="!rounded-full flex-1 md:flex-initial md:w-auto" + className="flex-1 md:flex-initial md:w-auto" > {t('重置')} @@ -711,9 +766,15 @@ const TokensTable = () => { bordered={false} >
{ + if (col.dataIndex === 'operate') { + const { fixed, ...rest } = col; + return rest; + } + return col; + }) : columns} dataSource={tokens} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} pagination={{ currentPage: activePage, pageSize: pageSize, diff --git a/web/src/components/table/UsersTable.js b/web/src/components/table/UsersTable.js index a027af59..02a19b80 100644 --- a/web/src/components/table/UsersTable.js +++ b/web/src/components/table/UsersTable.js @@ -13,7 +13,7 @@ import { Activity, Users, DollarSign, - UserPlus + UserPlus, } from 'lucide-react'; import { Button, @@ -26,6 +26,7 @@ import { Space, Table, Tag, + Tooltip, Typography } from '@douyinfe/semi-ui'; import { @@ -33,26 +34,21 @@ import { IllustrationNoResultDark } from '@douyinfe/semi-illustrations'; import { - IconPlus, IconSearch, - IconEdit, - IconDelete, - IconStop, - IconPlay, - IconMore, IconUserAdd, - IconArrowUp, - IconArrowDown + IconMore, } from '@douyinfe/semi-icons'; import { ITEMS_PER_PAGE } from '../../constants'; import AddUser from '../../pages/User/AddUser'; import EditUser from '../../pages/User/EditUser'; import { useTranslation } from 'react-i18next'; +import { useTableCompactMode } from '../../hooks/useTableCompactMode'; const { Text } = Typography; const UsersTable = () => { const { t } = useTranslation(); + const [compactMode, setCompactMode] = useTableCompactMode('users'); function renderRole(role) { switch (role) { @@ -110,6 +106,27 @@ const UsersTable = () => { { title: t('用户名'), dataIndex: 'username', + render: (text, record) => { + const remark = record.remark; + if (!remark) { + return {text}; + } + const maxLen = 10; + const displayRemark = remark.length > maxLen ? remark.slice(0, maxLen) + '…' : remark; + return ( + + {text} + + +
+
+ {displayRemark} +
+ + + + ); + }, }, { title: t('分组'), @@ -196,7 +213,6 @@ const UsersTable = () => { { node: 'item', name: t('提升'), - icon: , type: 'warning', onClick: () => { Modal.confirm({ @@ -211,7 +227,6 @@ const UsersTable = () => { { node: 'item', name: t('降级'), - icon: , type: 'secondary', onClick: () => { Modal.confirm({ @@ -226,7 +241,6 @@ const UsersTable = () => { { node: 'item', name: t('注销'), - icon: , type: 'danger', onClick: () => { Modal.confirm({ @@ -247,7 +261,6 @@ const UsersTable = () => { moreMenuItems.splice(-1, 0, { node: 'item', name: t('禁用'), - icon: , type: 'warning', onClick: () => { manageUser(record.id, 'disable', record); @@ -257,7 +270,6 @@ const UsersTable = () => { moreMenuItems.splice(-1, 0, { node: 'item', name: t('启用'), - icon: , type: 'secondary', onClick: () => { manageUser(record.id, 'enable', record); @@ -269,11 +281,9 @@ const UsersTable = () => { return (
@@ -518,8 +537,7 @@ const UsersTable = () => { @@ -591,7 +608,7 @@ const UsersTable = () => { }, 100); } }} - className="!rounded-full flex-1 md:flex-initial md:w-auto" + className="flex-1 md:flex-initial md:w-auto" > {t('重置')} @@ -623,9 +640,9 @@ const UsersTable = () => { bordered={false} >
rest) : columns} dataSource={users} - scroll={{ x: 'max-content' }} + scroll={compactMode ? undefined : { x: 'max-content' }} pagination={{ formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', { @@ -653,7 +670,7 @@ const UsersTable = () => { style={{ padding: 30 }} /> } - className="rounded-xl overflow-hidden" + className="overflow-hidden" size="middle" /> diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 20fed5b7..b145ea11 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -125,4 +125,16 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'Coze', }, + { + value: 50, + color: 'green', + label: '可灵', + }, + { + value: 51, + color: 'blue', + label: '即梦', + }, ]; + +export const MODEL_TABLE_PAGE_SIZE = 10; diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js index 1a37d5f6..6556ffef 100644 --- a/web/src/constants/common.constant.js +++ b/web/src/constants/common.constant.js @@ -1 +1,23 @@ export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! + +export const DEFAULT_ENDPOINT = '/api/ratio_config'; + +export const TABLE_COMPACT_MODES_KEY = 'table_compact_modes'; + +export const API_ENDPOINTS = [ + '/v1/chat/completions', + '/v1/responses', + '/v1/messages', + '/v1beta/models', + '/v1/embeddings', + '/v1/rerank', + '/v1/images/generations', + '/v1/images/edits', + '/v1/images/variations', + '/v1/audio/speech', + '/v1/audio/transcriptions', + '/v1/audio/translations' +]; + +export const TASK_ACTION_GENERATE = 'generate'; +export const TASK_ACTION_TEXT_GENERATE = 'textGenerate'; \ No newline at end of file diff --git a/web/src/helpers/api.js b/web/src/helpers/api.js index 5b9c03e1..a2f664ee 100644 --- a/web/src/helpers/api.js +++ b/web/src/helpers/api.js @@ -84,6 +84,7 @@ export const buildApiPayload = (messages, systemPrompt, inputs, parameterEnabled model: inputs.model, group: inputs.group, messages: processedMessages, + group: inputs.group, stream: inputs.stream, }; diff --git a/web/src/helpers/data.js b/web/src/helpers/data.js index bc1d28aa..afc29384 100644 --- a/web/src/helpers/data.js +++ b/web/src/helpers/data.js @@ -9,6 +9,7 @@ export function setStatusData(data) { localStorage.setItem('enable_task', data.enable_task); localStorage.setItem('enable_data_export', data.enable_data_export); localStorage.setItem('chats', JSON.stringify(data.chats)); + localStorage.setItem('pay_methods', JSON.stringify(data.pay_methods)); localStorage.setItem( 'data_export_default_time', data.data_export_default_time, diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index c9508203..6f00b914 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -883,7 +883,7 @@ function getEffectiveRatio(groupRatio, user_group_ratio) { ? i18next.t('专属倍率') : i18next.t('分组倍率'); const effectiveRatio = useUserGroupRatio ? user_group_ratio : groupRatio; - + return { ratio: effectiveRatio, label: ratioLabel, @@ -1074,25 +1074,25 @@ export function renderModelPrice( const extraServices = [ webSearch && webSearchCallCount > 0 ? i18next.t( - ' + Web搜索 {{count}}次 / 1K 次 * ${{price}} * {{ratioType}} {{ratio}}', - { - count: webSearchCallCount, - price: webSearchPrice, - ratio: groupRatio, - ratioType: ratioLabel, - }, - ) + ' + Web搜索 {{count}}次 / 1K 次 * ${{price}} * {{ratioType}} {{ratio}}', + { + count: webSearchCallCount, + price: webSearchPrice, + ratio: groupRatio, + ratioType: ratioLabel, + }, + ) : '', fileSearch && fileSearchCallCount > 0 ? i18next.t( - ' + 文件搜索 {{count}}次 / 1K 次 * ${{price}} * {{ratioType}} {{ratio}}', - { - count: fileSearchCallCount, - price: fileSearchPrice, - ratio: groupRatio, - ratioType: ratioLabel, - }, - ) + ' + 文件搜索 {{count}}次 / 1K 次 * ${{price}} * {{ratioType}} {{ratio}}', + { + count: fileSearchCallCount, + price: fileSearchPrice, + ratio: groupRatio, + ratioType: ratioLabel, + }, + ) : '', ].join(''); @@ -1281,10 +1281,10 @@ export function renderAudioModelPrice( let audioPrice = (audioInputTokens / 1000000) * inputRatioPrice * audioRatio * groupRatio + (audioCompletionTokens / 1000000) * - inputRatioPrice * - audioRatio * - audioCompletionRatio * - groupRatio; + inputRatioPrice * + audioRatio * + audioCompletionRatio * + groupRatio; let price = textPrice + audioPrice; return ( <> @@ -1340,27 +1340,27 @@ export function renderAudioModelPrice(

{cacheTokens > 0 ? i18next.t( - '文字提示 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 文字补全 {{completion}} tokens / 1M tokens * ${{compPrice}} = ${{total}}', - { - nonCacheInput: inputTokens - cacheTokens, - cacheInput: cacheTokens, - cachePrice: inputRatioPrice * cacheRatio, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - total: textPrice.toFixed(6), - }, - ) + '文字提示 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 文字补全 {{completion}} tokens / 1M tokens * ${{compPrice}} = ${{total}}', + { + nonCacheInput: inputTokens - cacheTokens, + cacheInput: cacheTokens, + cachePrice: inputRatioPrice * cacheRatio, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + total: textPrice.toFixed(6), + }, + ) : i18next.t( - '文字提示 {{input}} tokens / 1M tokens * ${{price}} + 文字补全 {{completion}} tokens / 1M tokens * ${{compPrice}} = ${{total}}', - { - input: inputTokens, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - total: textPrice.toFixed(6), - }, - )} + '文字提示 {{input}} tokens / 1M tokens * ${{price}} + 文字补全 {{completion}} tokens / 1M tokens * ${{compPrice}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + total: textPrice.toFixed(6), + }, + )}

{i18next.t( @@ -1397,7 +1397,7 @@ export function renderQuotaWithPrompt(quota, digits) { displayInCurrency = displayInCurrency === 'true'; if (displayInCurrency) { return ( - ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '' + i18next.t('等价金额:') + renderQuota(quota, digits) ); } return ''; @@ -1499,35 +1499,35 @@ export function renderClaudeModelPrice(

{cacheTokens > 0 || cacheCreationTokens > 0 ? i18next.t( - '提示 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 缓存创建 {{cacheCreationInput}} tokens / 1M tokens * ${{cacheCreationPrice}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * {{ratioType}} {{ratio}} = ${{total}}', - { - nonCacheInput: nonCachedTokens, - cacheInput: cacheTokens, - cacheRatio: cacheRatio, - cacheCreationInput: cacheCreationTokens, - cacheCreationRatio: cacheCreationRatio, - cachePrice: cacheRatioPrice, - cacheCreationPrice: cacheCreationRatioPrice, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - ratio: groupRatio, - ratioType: ratioLabel, - total: price.toFixed(6), - }, - ) + '提示 {{nonCacheInput}} tokens / 1M tokens * ${{price}} + 缓存 {{cacheInput}} tokens / 1M tokens * ${{cachePrice}} + 缓存创建 {{cacheCreationInput}} tokens / 1M tokens * ${{cacheCreationPrice}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * {{ratioType}} {{ratio}} = ${{total}}', + { + nonCacheInput: nonCachedTokens, + cacheInput: cacheTokens, + cacheRatio: cacheRatio, + cacheCreationInput: cacheCreationTokens, + cacheCreationRatio: cacheCreationRatio, + cachePrice: cacheRatioPrice, + cacheCreationPrice: cacheCreationRatioPrice, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + ratioType: ratioLabel, + total: price.toFixed(6), + }, + ) : i18next.t( - '提示 {{input}} tokens / 1M tokens * ${{price}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * {{ratioType}} {{ratio}} = ${{total}}', - { - input: inputTokens, - price: inputRatioPrice, - completion: completionTokens, - compPrice: completionRatioPrice, - ratio: groupRatio, - ratioType: ratioLabel, - total: price.toFixed(6), - }, - )} + '提示 {{input}} tokens / 1M tokens * ${{price}} + 补全 {{completion}} tokens / 1M tokens * ${{compPrice}} * {{ratioType}} {{ratio}} = ${{total}}', + { + input: inputTokens, + price: inputRatioPrice, + completion: completionTokens, + compPrice: completionRatioPrice, + ratio: groupRatio, + ratioType: ratioLabel, + total: price.toFixed(6), + }, + )}

{i18next.t('仅供参考,以实际扣费为准')}

diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js index 56e1104d..68a05846 100644 --- a/web/src/helpers/utils.js +++ b/web/src/helpers/utils.js @@ -3,6 +3,7 @@ import { toastConstants } from '../constants'; import React from 'react'; import { toast } from 'react-toastify'; import { THINK_TAG_REGEX, MESSAGE_ROLES } from '../constants/playground.constants'; +import { TABLE_COMPACT_MODES_KEY } from '../constants'; const HTMLToastContent = ({ htmlContent }) => { return
; @@ -509,3 +510,31 @@ export const formatDateTimeString = (date) => { const minutes = String(date.getMinutes()).padStart(2, '0'); return `${year}-${month}-${day} ${hours}:${minutes}`; }; + +function readTableCompactModes() { + try { + const json = localStorage.getItem(TABLE_COMPACT_MODES_KEY); + return json ? JSON.parse(json) : {}; + } catch { + return {}; + } +} + +function writeTableCompactModes(modes) { + try { + localStorage.setItem(TABLE_COMPACT_MODES_KEY, JSON.stringify(modes)); + } catch { + // ignore + } +} + +export function getTableCompactMode(tableKey = 'global') { + const modes = readTableCompactModes(); + return !!modes[tableKey]; +} + +export function setTableCompactMode(compact, tableKey = 'global') { + const modes = readTableCompactModes(); + modes[tableKey] = compact; + writeTableCompactModes(modes); +} diff --git a/web/src/hooks/useTableCompactMode.js b/web/src/hooks/useTableCompactMode.js new file mode 100644 index 00000000..f943bda7 --- /dev/null +++ b/web/src/hooks/useTableCompactMode.js @@ -0,0 +1,34 @@ +import { useState, useEffect, useCallback } from 'react'; +import { getTableCompactMode, setTableCompactMode } from '../helpers'; +import { TABLE_COMPACT_MODES_KEY } from '../constants'; + +/** + * 自定义 Hook:管理表格紧凑/自适应模式 + * 返回 [compactMode, setCompactMode]。 + * 内部使用 localStorage 保存状态,并监听 storage 事件保持多标签页同步。 + */ +export function useTableCompactMode(tableKey = 'global') { + const [compactMode, setCompactModeState] = useState(() => getTableCompactMode(tableKey)); + + const setCompactMode = useCallback((value) => { + setCompactModeState(value); + setTableCompactMode(value, tableKey); + }, [tableKey]); + + useEffect(() => { + const handleStorage = (e) => { + if (e.key === TABLE_COMPACT_MODES_KEY) { + try { + const modes = JSON.parse(e.newValue || '{}'); + setCompactModeState(!!modes[tableKey]); + } catch { + // ignore parse error + } + } + }; + window.addEventListener('storage', handleStorage); + return () => window.removeEventListener('storage', handleStorage); + }, [tableKey]); + + return [compactMode, setCompactMode]; +} \ No newline at end of file diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index ba23ca5c..80f7f3cd 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -139,7 +139,7 @@ "已成功开始测试所有已启用通道,请刷新页面查看结果。": "Successfully started testing all enabled channels. Please refresh page to view results.", "通道 ${name} 余额更新成功!": "Channel ${name} quota updated successfully!", "已更新完毕所有已启用通道余额!": "Updated quota for all enabled channels!", - "搜索渠道的 ID,名称,密钥和API地址 ...": "Search channel ID, name, key and Base URL...", + "渠道ID,名称,密钥,API地址": "Channel ID, name, key, Base URL", "名称": "Name", "分组": "Group", "类型": "Type", @@ -397,7 +397,7 @@ "删除用户": "Delete User", "添加新的用户": "Add New User", "自定义": "Custom", - "等价金额": "Equivalent Amount", + "等价金额:": "Equivalent Amount: ", "未登录或登录已过期,请重新登录": "Not logged in or login has expired, please log in again", "请求次数过多,请稍后再试": "Too many requests, please try again later", "服务器内部错误,请联系管理员": "Server internal error, please contact the administrator", @@ -428,6 +428,7 @@ "填入基础模型": "Fill in the basic model", "填入所有模型": "Fill in all models", "清除所有模型": "Clear all models", + "复制所有模型": "Copy all models", "密钥": "Key", "请输入密钥": "Please enter the key", "批量创建": "Batch Create", @@ -455,8 +456,8 @@ "创建新的令牌": "Create New Token", "令牌分组,默认为用户的分组": "Token group, default is the your's group", "IP白名单": "IP whitelist", - "注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.", - "设为无限额度": "Set to unlimited quota", + "令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制": "The quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account", + "无限额度": "Unlimited quota", "更新令牌信息": "Update Token Information", "请输入充值码!": "Please enter the recharge code!", "请输入名称": "Please enter a name", @@ -470,10 +471,11 @@ "请输入新的密码": "Please enter a new password", "显示名称": "Display Name", "请输入新的显示名称": "Please enter a new display name", - "已绑定的 GitHub 账户": "GitHub Account Bound", - "此项只读,要用户通过个人设置页面的相关绑��按钮进��绑���,不可直接修改": "This item is read-only. Users need to bind through the relevant binding button on the personal settings page, and cannot be modified directly", - "已绑定的微信账户": "WeChat Account Bound", - "已绑定的邮箱账户": "Email Account Bound", + "已绑定的 GITHUB 账户": "Bound GitHub Account", + "已绑定的 WECHAT 账户": "Bound WeChat Account", + "已绑定的 EMAIL 账户": "Bound Email Account", + "已绑定的 TELEGRAM 账户": "Bound Telegram Account", + "此项只读,要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only. Users need to bind through the relevant binding button on the personal settings page, and cannot be modified directly", "用户信息更新成功!": "User information updated successfully!", "使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", "用户名称": "User Name", @@ -515,7 +517,6 @@ "注意,系统请求的时模型名称中的点会被剔除,例如:gpt-4.1会请求为gpt-41,所以在Azure部署的时候,部署模型名称需要手动改为gpt-41": "Note that the dot in the model name requested by the system will be removed, for example: gpt-4.1 will be requested as gpt-41, so when deploying on Azure, the deployment model name needs to be manually changed to gpt-41", "2025年5月10日后添加的渠道,不需要再在部署的时候移除模型名称中的\".\"": "After May 10, 2025, channels added do not need to remove the dot in the model name during deployment", "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", - "取消无限额度": "Cancel unlimited quota", "取消": "Cancel", "重置": "Reset", "请输入新的剩余额度": "Please enter the new remaining quota", @@ -800,6 +801,7 @@ "获取无水印": "Get no watermark", "生成图片": "Generate pictures", "可灵": "Kling", + "即梦": "Jimeng", "正在提交": "Submitting", "执行中": "processing", "平台": "platform", @@ -813,7 +815,16 @@ "复制所选令牌": "Copy selected token", "请至少选择一个令牌!": "Please select at least one token!", "管理员未设置查询页链接": "The administrator has not set the query page link", - "复制所选令牌到剪贴板": "Copy selected token to clipboard", + "批量删除令牌": "Batch delete token", + "确定要删除所选的 {{count}} 个令牌吗?": "Are you sure you want to delete the selected {{count}} tokens?", + "删除所选令牌": "Delete selected token", + "请先选择要删除的令牌!": "Please select the token to be deleted!", + "已删除 {{count}} 个令牌!": "Deleted {{count}} tokens!", + "删除失败": "Delete failed", + "复制令牌": "Copy token", + "请选择你的复制方式": "Please select your copy method", + "名称+密钥": "Name + key", + "仅密钥": "Only key", "查看API地址": "View API address", "打开查询页": "Open query page", "时间(仅显示近3天)": "Time (only displays the last 3 days)", @@ -865,7 +876,7 @@ "加载token失败": "Failed to load token", "配置聊天": "Configure chat", "模型消耗分布": "Model consumption distribution", - "模型调用次数占比": "Proportion of model calls", + "模型调用次数占比": "Model call ratio", "用户消耗分布": "User consumption distribution", "时间粒度": "Time granularity", "天": "day", @@ -1108,6 +1119,10 @@ "平均TPM": "Average TPM", "消耗分布": "Consumption distribution", "调用次数分布": "Models call distribution", + "消耗趋势": "Consumption trend", + "模型消耗趋势": "Model consumption trend", + "调用次数排行": "Models call ranking", + "模型调用次数排行": "Model call ranking", "添加渠道": "Add channel", "测试所有通道": "Test all channels", "删除禁用通道": "Delete disabled channels", @@ -1132,8 +1147,8 @@ "默认测试模型": "Default Test Model", "不填则为模型列表第一个": "First model in list if empty", "是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道": "Auto-disable (only effective when auto-disable is enabled). When turned off, this channel will not be automatically disabled", - "状态码复写(仅影响本地判断,不修改返回到上游的状态码)": "Status Code Override (only affects local judgment, does not modify status code returned upstream)", - "此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "Optional, used to override returned status codes, e.g. rewriting Claude channel's 400 error to 500 (for retry). Do not abuse this feature. Example:", + "状态码复写": "Status Code Override", + "此项可选,用于复写返回的状态码,仅影响本地判断,不修改返回到上游的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:": "Optional, used to override returned status codes, only affects local judgment, does not modify status code returned upstream, e.g. rewriting Claude channel's 400 error to 500 (for retry). Do not abuse this feature. Example:", "渠道标签": "Channel Tag", "渠道优先级": "Channel Priority", "渠道权重": "Channel Weight", @@ -1188,7 +1203,7 @@ "添加用户": "Add user", "角色": "Role", "已绑定的 Telegram 账户": "Bound Telegram account", - "新额度": "New quota", + "新额度:": "New quota: ", "需要添加的额度(支持负数)": "Need to add quota (supports negative numbers)", "此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "Read-only, user's personal settings, and cannot be modified directly", "请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characterss", @@ -1206,7 +1221,7 @@ "默认折叠侧边栏": "Default collapse sidebar", "聊天链接功能已经弃用,请使用下方聊天设置功能": "Chat link function has been deprecated, please use the chat settings below", "你似乎并没有修改什么": "You seem to have not modified anything", - "令牌聊天设置": "Chat settings", + "聊天设置": "Chat settings", "必须将上方聊天链接全部设置为空,才能使用下方聊天设置功能": "Must set all chat links above to empty to use the chat settings below", "链接中的{key}将自动替换为sk-xxxx,{address}将自动替换为系统设置的服务器地址,末尾不带/和/v1": "The {key} in the link will be automatically replaced with sk-xxxx, the {address} will be automatically replaced with the server address in system settings, and the end will not have / and /v1", "聊天配置": "Chat configuration", @@ -1263,7 +1278,7 @@ " 吗?": "?", "修改子渠道优先级": "Modify sub-channel priority", "确定要修改所有子渠道优先级为 ": "Confirm to modify all sub-channel priorities to ", - "分组设置": "Group settings", + "分组倍率设置": "Group ratio settings", "用户可选分组": "User selectable groups", "保存分组倍率设置": "Save group ratio settings", "模型倍率设置": "Model ratio settings", @@ -1373,6 +1388,12 @@ "示例": "Example", "缺省 MaxTokens": "Default MaxTokens", "启用Claude思考适配(-thinking后缀)": "Enable Claude thinking adaptation (-thinking suffix)", + "和Claude不同,默认情况下Gemini的思考模型会自动决定要不要思考,就算不开启适配模型也可以正常使用,": "Unlike Claude, Gemini's thinking model automatically decides whether to think by default, and can be used normally even without enabling the adaptation model.", + "如果您需要计费,推荐设置无后缀模型价格按思考价格设置。": "If you need billing, it is recommended to set the no-suffix model price according to the thinking price.", + "支持使用 gemini-2.5-pro-preview-06-05-thinking-128 格式来精确传递思考预算。": "Supports using gemini-2.5-pro-preview-06-05-thinking-128 format to precisely pass thinking budget.", + "启用Gemini思考后缀适配": "Enable Gemini thinking suffix adaptation", + "适配-thinking、-thinking-预算数字和-nothinking后缀": "Adapt -thinking, -thinking-budgetNumber, and -nothinking suffixes", + "思考预算占比": "Thinking budget ratio", "Claude思考适配 BudgetTokens = MaxTokens * BudgetTokens 百分比": "Claude thinking adaptation BudgetTokens = MaxTokens * BudgetTokens percentage", "思考适配 BudgetTokens 百分比": "Thinking adaptation BudgetTokens percentage", "0.1-1之间的小数": "Decimal between 0.1 and 1", @@ -1406,8 +1427,8 @@ "初始化系统": "Initialize system", "支持众多的大模型供应商": "Supporting various LLM providers", "统一的大模型接口网关": "The Unified LLMs API Gateway", - "更好的价格,更好的稳定性,无需订阅": "Better price, better stability, no subscription required", - "开始使用": "Get Started", + "更好的价格,更好的稳定性,只需要将模型基址替换为:": "Better price, better stability, no subscription required, just replace the model BASE URL with: ", + "获取密钥": "Get Key", "关于我们": "About Us", "关于项目": "About Project", "联系我们": "Contact Us", @@ -1443,7 +1464,8 @@ "访问限制": "Access Restrictions", "设置令牌的访问限制": "Set token access restrictions", "请勿过度信任此功能,IP可能被伪造": "Do not over-trust this feature, IP can be spoofed", - "勾选启用模型限制后可选择": "Select after checking to enable model restrictions", + "模型限制列表": "Model restrictions list", + "请选择该令牌支持的模型,留空支持所有模型": "Select models supported by the token, leave blank to support all models", "非必要,不建议启用模型限制": "Not necessary, model restrictions are not recommended", "分组信息": "Group Information", "设置令牌的分组": "Set token grouping", @@ -1582,7 +1604,7 @@ "性能指标": "Performance Indicators", "模型数据分析": "Model Data Analysis", "搜索无结果": "No results found", - "仪表盘配置": "Dashboard Configuration", + "仪表盘设置": "Dashboard Settings", "API信息管理,可以配置多个API地址用于状态展示和负载均衡(最多50个)": "API information management, you can configure multiple API addresses for status display and load balancing (maximum 50)", "线路描述": "Route description", "颜色": "Color", @@ -1609,6 +1631,7 @@ "编辑公告": "Edit Notice", "公告内容": "Notice Content", "请输入公告内容": "Please enter the notice content", + "请输入公告内容(支持 Markdown/HTML)": "Please enter the notice content (supports Markdown/HTML)", "发布日期": "Publish Date", "请选择发布日期": "Please select the publish date", "发布时间": "Publish Time", @@ -1624,6 +1647,7 @@ "请输入问题标题": "Please enter the question title", "回答内容": "Answer Content", "请输入回答内容": "Please enter the answer content", + "请输入回答内容(支持 Markdown/HTML)": "Please enter the answer content (supports Markdown/HTML)", "确定要删除此问答吗?": "Are you sure you want to delete this FAQ?", "系统公告管理,可以发布系统通知和重要消息(最多100个,前端显示最新20条)": "System notice management, you can publish system notices and important messages (maximum 100, display latest 20 on the front end)", "常见问答管理,为用户提供常见问题的答案(最多50个,前端显示最新20条)": "FAQ management, providing answers to common questions for users (maximum 50, display latest 20 on the front end)", @@ -1658,5 +1682,79 @@ "清除失效兑换码": "Clear invalid redemption codes", "确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?", "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.", - "选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)" + "选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)", + "请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)", + "上游倍率同步": "Upstream ratio synchronization", + "获取渠道失败:": "Failed to get channels: ", + "请至少选择一个渠道": "Please select at least one channel", + "获取倍率失败:": "Failed to get ratios: ", + "后端请求失败": "Backend request failed", + "部分渠道测试失败:": "Some channels failed to test: ", + "未找到差异化倍率,无需同步": "No differential ratio found, no synchronization is required", + "请求后端接口失败:": "Failed to request the backend interface: ", + "同步成功": "Synchronization successful", + "部分保存失败": "Some settings failed to save", + "保存失败": "Save failed", + "选择同步渠道": "Select synchronization channel", + "应用同步": "Apply synchronization", + "倍率类型": "Ratio type", + "当前值": "Current value", + "上游值": "Upstream value", + "差异": "Difference", + "搜索渠道名称或地址": "Search channel name or address", + "缓存倍率": "Cache ratio", + "暂无差异化倍率显示": "No differential ratio display", + "请先选择同步渠道": "Please select the synchronization channel first", + "与本地相同": "Same as local", + "未找到匹配的模型": "No matching model found", + "暴露倍率接口": "Expose ratio API", + "支付设置": "Payment Settings", + "(当前仅支持易支付接口,默认使用上方服务器地址作为回调地址!)": "(Currently only supports Epay interface, the default callback address is the server address above!)", + "支付地址": "Payment address", + "易支付商户ID": "Epay merchant ID", + "易支付商户密钥": "Epay merchant key", + "回调地址": "Callback address", + "充值价格(x元/美金)": "Recharge price (x yuan/dollar)", + "最低充值美元数量": "Minimum recharge dollar amount", + "充值分组倍率": "Recharge group ratio", + "充值方式设置": "Recharge method settings", + "更新支付设置": "Update payment settings", + "通知": "Notice", + "源地址": "Source address", + "同步接口": "Synchronization interface", + "置信度": "Confidence", + "谨慎": "Cautious", + "该数据可能不可信,请谨慎使用": "This data may not be reliable, please use with caution", + "可信": "Reliable", + "所有上游数据均可信": "All upstream data is reliable", + "以下上游数据可能不可信:": "The following upstream data may not be reliable: ", + "按倍率类型筛选": "Filter by ratio type", + "内容": "Content", + "放大编辑": "Expand editor", + "编辑公告内容": "Edit announcement content", + "自适应列表": "Adaptive list", + "紧凑列表": "Compact list", + "仅显示矛盾倍率": "Only show conflicting ratios", + "矛盾": "Conflict", + "确认冲突项修改": "Confirm conflict item modification", + "该模型存在固定价格与倍率计费方式冲突,请确认选择": "The model has a fixed price and ratio billing method conflict, please confirm the selection", + "当前计费": "Current billing", + "修改为": "Modify to", + "状态筛选": "Status filter", + "没有模型可以复制": "No models to copy", + "模型列表已复制到剪贴板": "Model list copied to clipboard", + "复制失败": "Copy failed", + "复制已选": "Copy selected", + "选择成功": "Selection successful", + "暂无成功模型": "No successful models", + "请先选择模型!": "Please select a model first!", + "已复制 ${count} 个模型": "Copied ${count} models", + "复制失败,请手动复制": "Copy failed, please copy manually", + "过期时间快捷设置": "Expiration time quick settings", + "批量创建时会在名称后自动添加随机后缀": "When creating in batches, a random suffix will be automatically added to the name", + "额度必须大于0": "Quota must be greater than 0", + "生成数量必须大于0": "Generation quantity must be greater than 0", + "创建后可在编辑渠道时获取上游模型列表": "After creation, you can get the upstream model list when editing the channel", + "可用端点类型": "Supported endpoint types", + "未登录,使用默认分组倍率:": "Not logged in, using default group ratio: " } \ No newline at end of file diff --git a/web/src/index.css b/web/src/index.css index c1254fcc..8e71536a 100644 --- a/web/src/index.css +++ b/web/src/index.css @@ -43,6 +43,7 @@ code { /* ==================== 导航和侧边栏样式 ==================== */ /* 导航项样式 */ +.semi-input-textarea-wrapper, .semi-navigation-sub-title, .semi-chat-inputBox-sendButton, .semi-page-item, @@ -53,7 +54,7 @@ code { .semi-select, .semi-button, .semi-datepicker-range-input { - border-radius: 9999px !important; + border-radius: 10px !important; } .semi-navigation-item { @@ -375,6 +376,7 @@ code { } /* 隐藏卡片内容区域的滚动条 */ +.model-test-scroll, .card-content-scroll, .model-settings-scroll, .thinking-content-scroll, @@ -385,6 +387,7 @@ code { scrollbar-width: none; } +.model-test-scroll::-webkit-scrollbar, .card-content-scroll::-webkit-scrollbar, .model-settings-scroll::-webkit-scrollbar, .thinking-content-scroll::-webkit-scrollbar, @@ -432,4 +435,162 @@ code { .semi-table-tbody>.semi-table-row { border-bottom: 1px solid rgba(0, 0, 0, 0.1); } +} + +/* ==================== 同步倍率 - 渠道选择器 ==================== */ + +.components-transfer-source-item, +.components-transfer-selected-item { + display: flex; + align-items: center; + padding: 8px; +} + +.semi-transfer-left-list, +.semi-transfer-right-list { + -ms-overflow-style: none; + scrollbar-width: none; +} + +.semi-transfer-left-list::-webkit-scrollbar, +.semi-transfer-right-list::-webkit-scrollbar { + display: none; +} + +.components-transfer-source-item .semi-checkbox, +.components-transfer-selected-item .semi-checkbox { + display: flex; + align-items: center; + width: 100%; +} + +.components-transfer-source-item .semi-avatar, +.components-transfer-selected-item .semi-avatar { + margin-right: 12px; + flex-shrink: 0; +} + +.components-transfer-source-item .info, +.components-transfer-selected-item .info { + flex: 1; + overflow: hidden; + display: flex; + flex-direction: column; + justify-content: center; +} + +.components-transfer-source-item .name, +.components-transfer-selected-item .name { + font-weight: 500; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.components-transfer-source-item .email, +.components-transfer-selected-item .email { + font-size: 12px; + color: var(--semi-color-text-2); + display: flex; + align-items: center; +} + +.components-transfer-selected-item .semi-icon-close { + margin-left: 8px; + cursor: pointer; + color: var(--semi-color-text-2); +} + +.components-transfer-selected-item .semi-icon-close:hover { + color: var(--semi-color-text-0); +} + +/* ==================== 未读通知闪光效果 ==================== */ +@keyframes sweep-shine { + 0% { + background-position: 200% 0; + } + + 100% { + background-position: -200% 0; + } +} + +.shine-text { + background: linear-gradient(90deg, currentColor 0%, currentColor 40%, rgba(255, 255, 255, 0.9) 50%, currentColor 60%, currentColor 100%); + background-size: 200% 100%; + -webkit-background-clip: text; + background-clip: text; + -webkit-text-fill-color: transparent; + animation: sweep-shine 4s linear infinite; +} + +.dark .shine-text { + background: linear-gradient(90deg, currentColor 0%, currentColor 40%, #facc15 50%, currentColor 60%, currentColor 100%); + background-size: 200% 100%; + -webkit-background-clip: text; + background-clip: text; + -webkit-text-fill-color: transparent; +} + +/* ==================== ScrollList 定制样式 ==================== */ +.semi-scrolllist, +.semi-scrolllist * { + -ms-overflow-style: none; + /* IE, Edge */ + scrollbar-width: none; + /* Firefox */ + background: transparent !important; +} + +.semi-scrolllist::-webkit-scrollbar, +.semi-scrolllist *::-webkit-scrollbar { + width: 0 !important; + height: 0 !important; + display: none !important; +} + +.semi-scrolllist-body { + padding: 1px !important; +} + +.semi-scrolllist-list-outer { + padding-right: 0 !important; +} + +/* ==================== Banner 背景模糊球 ==================== */ +.blur-ball { + position: absolute; + width: 360px; + height: 360px; + border-radius: 50%; + filter: blur(120px); + pointer-events: none; + z-index: -1; +} + +.blur-ball-indigo { + background: #6366f1; + /* indigo-500 */ + top: 40px; + left: 50%; + transform: translateX(-50%); + opacity: 0.5; +} + +.blur-ball-teal { + background: #14b8a6; + /* teal-400 */ + top: 200px; + left: 30%; + opacity: 0.4; +} + +/* 浅色主题下让模糊球更柔和 */ +html:not(.dark) .blur-ball-indigo { + opacity: 0.25; +} + +html:not(.dark) .blur-ball-teal { + opacity: 0.2; } \ No newline at end of file diff --git a/web/src/index.js b/web/src/index.js index ef8a3a07..ef299ea2 100644 --- a/web/src/index.js +++ b/web/src/index.js @@ -5,7 +5,6 @@ import '@douyinfe/semi-ui/dist/css/semi.css'; import { UserProvider } from './context/User'; import 'react-toastify/dist/ReactToastify.css'; import { StatusProvider } from './context/Status'; -import { Layout } from '@douyinfe/semi-ui'; import { ThemeProvider } from './context/Theme'; import { StyleProvider } from './context/Style/index.js'; import PageLayout from './components/layout/PageLayout.js'; @@ -15,7 +14,6 @@ import './index.css'; // initialization const root = ReactDOM.createRoot(document.getElementById('root')); -const { Sider, Content, Header, Footer } = Layout; root.render( diff --git a/web/src/pages/About/index.js b/web/src/pages/About/index.js index 3259449e..032562ca 100644 --- a/web/src/pages/About/index.js +++ b/web/src/pages/About/index.js @@ -105,7 +105,7 @@ const About = () => { ); return ( - <> +
{aboutLoaded && about === '' ? (
{ )} )} - +
); }; diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 32d2ce49..cfed54e4 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -1,4 +1,4 @@ -import React, { useEffect, useState } from 'react'; +import React, { useEffect, useState, useRef } from 'react'; import { useNavigate } from 'react-router-dom'; import { useTranslation } from 'react-i18next'; import { @@ -15,19 +15,19 @@ import { Space, Spin, Button, - Input, Typography, - Select, - TextArea, Checkbox, Banner, Modal, ImagePreview, Card, Tag, - Upload, + Avatar, + Form, + Row, + Col, } from '@douyinfe/semi-ui'; -import { getChannelModels } from '../../helpers'; +import { getChannelModels, copy } from '../../helpers'; import { IconSave, IconClose, @@ -35,7 +35,6 @@ import { IconSetting, IconCode, IconGlobe, - IconBolt, } from '@douyinfe/semi-icons'; const { Text, Title } = Typography; @@ -66,6 +65,10 @@ function type2secretPrompt(type) { return '按照如下格式输入:AppId|SecretId|SecretKey'; case 33: return '按照如下格式输入:Ak|Sk|Region'; + case 50: + return '按照如下格式输入: AccessKey|SecretKey'; + case 51: + return '按照如下格式输入: Access Key ID|Secret Access Key'; default: return '请输入渠道对应的鉴权密钥'; } @@ -99,9 +102,8 @@ const EditChannel = (props) => { tag: '', }; const [batch, setBatch] = useState(false); - const [mergeToSingle, setMergeToSingle] = useState(false); const [autoBan, setAutoBan] = useState(true); - const [jsonFiles, setJsonFiles] = useState([]); + // const [autoBan, setAutoBan] = useState(true); const [inputs, setInputs] = useState(originInputs); const [originModelOptions, setOriginModelOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); @@ -111,7 +113,16 @@ const EditChannel = (props) => { const [customModel, setCustomModel] = useState(''); const [modalImageUrl, setModalImageUrl] = useState(''); const [isModalOpenurl, setIsModalOpenurl] = useState(false); + const formApiRef = useRef(null); + const getInitValues = () => ({ ...originInputs }); const handleInputChange = (name, value) => { + if (formApiRef.current) { + formApiRef.current.setValue(name, value); + } + if (name === 'models' && Array.isArray(value)) { + value = Array.from(new Set(value.map((m) => (m || '').trim()))); + } + if (name === 'base_url' && value.endsWith('/v1')) { Modal.confirm({ title: '警告', @@ -142,6 +153,8 @@ const EditChannel = (props) => { localModels = [ 'swap_face', 'mj_imagine', + 'mj_video', + 'mj_edits', 'mj_variation', 'mj_reroll', 'mj_blend', @@ -199,6 +212,9 @@ const EditChannel = (props) => { ); } setInputs(data); + if (formApiRef.current) { + formApiRef.current.setValues(data); + } if (data.auto_ban === 0) { setAutoBan(false); } else { @@ -266,10 +282,14 @@ const EditChannel = (props) => { const fetchModels = async () => { try { let res = await API.get(`/api/channel/models`); - let localModelOptions = res.data.data.map((model) => ({ - label: model.id, - value: model.id, - })); + const localModelOptions = res.data.data.map((model) => { + const id = (model.id || '').trim(); + return { + key: id, + label: id, + value: id, + }; + }); setOriginModelOptions(localModelOptions); setFullModels(res.data.data.map((model) => model.id)); setBasicModels( @@ -302,56 +322,77 @@ const EditChannel = (props) => { }; useEffect(() => { - let localModelOptions = [...originModelOptions]; - inputs.models.forEach((model) => { - if (!localModelOptions.find((option) => option.label === model)) { - localModelOptions.push({ - label: model, - value: model, + const modelMap = new Map(); + + originModelOptions.forEach(option => { + const v = (option.value || '').trim(); + if (!modelMap.has(v)) { + modelMap.set(v, option); + } + }); + + inputs.models.forEach(model => { + const v = (model || '').trim(); + if (!modelMap.has(v)) { + modelMap.set(v, { + key: v, + label: v, + value: v, }); } }); - setModelOptions(localModelOptions); + + setModelOptions(Array.from(modelMap.values())); }, [originModelOptions, inputs.models]); useEffect(() => { fetchModels().then(); fetchGroups().then(); - if (isEdit) { - loadChannel().then(() => { }); - } else { + if (!isEdit) { setInputs(originInputs); + if (formApiRef.current) { + formApiRef.current.setValues(originInputs); + } let localModels = getChannelModels(inputs.type); setBasicModels(localModels); setInputs((inputs) => ({ ...inputs, models: localModels })); } }, [props.editingChannel.id]); - const submit = async () => { - if (!isEdit) { - if (inputs.name === '') { - showInfo(t('请填写渠道名称!')); - return; - } - if (inputs.type === 41 && batch) { - if (jsonFiles.length === 0) { - showInfo(t('请至少选择一个 JSON 凭证文件!')); - return; - } - } else if (inputs.key === '') { - showInfo(t('请填写渠道密钥!')); - return; - } + useEffect(() => { + if (formApiRef.current) { + formApiRef.current.setValues(inputs); } - if (inputs.models.length === 0) { + }, [inputs]); + + useEffect(() => { + if (props.visible) { + if (isEdit) { + loadChannel(); + } else { + formApiRef.current?.setValues(getInitValues()); + } + } else { + formApiRef.current?.reset(); + } + }, [props.visible, channelId]); + + const submit = async () => { + const formValues = formApiRef.current ? formApiRef.current.getValues() : {}; + let localInputs = { ...formValues }; + + if (!isEdit && (!localInputs.name || !localInputs.key)) { + showInfo(t('请填写渠道名称和渠道密钥!')); + return; + } + if (!Array.isArray(localInputs.models) || localInputs.models.length === 0) { showInfo(t('请至少选择一个模型!')); return; } - if (inputs.model_mapping !== '' && !verifyJSON(inputs.model_mapping)) { + if (localInputs.model_mapping && localInputs.model_mapping !== '' && !verifyJSON(localInputs.model_mapping)) { showInfo(t('模型映射必须是合法的 JSON 格式!')); return; } - let localInputs = { ...inputs }; if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice( 0, @@ -362,40 +403,16 @@ const EditChannel = (props) => { localInputs.other = 'v2.1'; } let res; - if (!Array.isArray(localInputs.models)) { - showError(t('提交失败,请勿重复提交!')); - handleCancel(); - return; - } - localInputs.auto_ban = autoBan ? 1 : 0; + localInputs.auto_ban = localInputs.auto_ban ? 1 : 0; localInputs.models = localInputs.models.join(','); - localInputs.group = localInputs.groups.join(','); - - if (inputs.type === 41 && batch) { - const keyObj = {}; - jsonFiles.forEach((content, idx) => { - keyObj[content] = idx; - }); - localInputs.key = JSON.stringify(keyObj); - } - - let mode = 'single'; - if (batch) { - mode = mergeToSingle ? 'multi_to_single' : 'batch'; - } - - const payload = { - mode, - channel: localInputs, - }; - + localInputs.group = (localInputs.groups || []).join(','); if (isEdit) { res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId), }); } else { - res = await API.post(`/api/channel/`, payload); + res = await API.post(`/api/channel/`, localInputs); } const { success, message } = res.data; if (success) { @@ -425,7 +442,7 @@ const EditChannel = (props) => { localModels.push(model); localModelOptions.push({ key: model, - text: model, + label: model, value: model, }); addedModels.push(model); @@ -448,17 +465,10 @@ const EditChannel = (props) => { } }; - const handleJsonFileUpload = (file) => { - return new Promise((resolve) => { - const reader = new FileReader(); - reader.onload = (e) => { - const content = e.target.result; - setJsonFiles((prev) => [...prev, content]); - resolve({ shouldUpload: false, status: 'success' }); - }; - reader.readAsText(file); - }); - }; + const batchAllowed = !isEdit && inputs.type !== 41; + const batchExtra = batchAllowed ? ( + setBatch(!batch)}>{t('批量创建')} + ) : null; return ( <> @@ -472,14 +482,7 @@ const EditChannel = (props) => { } - headerStyle={{ - borderBottom: '1px solid var(--semi-color-border)', - padding: '24px' - }} - bodyStyle={{ - backgroundColor: 'var(--semi-color-bg-0)', - padding: '0' - }} + bodyStyle={{ padding: '0' }} visible={props.visible} width={isMobile() ? '100%' : 600} footer={ @@ -487,17 +490,13 @@ const EditChannel = (props) => {