diff --git a/common/constants.go b/common/constants.go index bee00506..ac803148 100644 --- a/common/constants.go +++ b/common/constants.go @@ -241,6 +241,7 @@ const ( ChannelTypeXinference = 47 ChannelTypeXai = 48 ChannelTypeCoze = 49 + ChannelTypeKling = 50 ChannelTypeDummy // this one is only for count, do not add any channel after this ) @@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{ "", //47 "https://api.x.ai", //48 "https://api.coze.cn", //49 + "https://api.klingai.com", //50 } diff --git a/common/utils.go b/common/utils.go index d9db67d0..17aecd95 100644 --- a/common/utils.go +++ b/common/utils.go @@ -13,6 +13,7 @@ import ( "math/big" "math/rand" "net" + "net/url" "os" "os/exec" "runtime" @@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64 } return strconv.ParseFloat(durationStr, 64) } + +// BuildURL concatenates base and endpoint, returns the complete url string +func BuildURL(base string, endpoint string) string { + u, err := url.Parse(base) + if err != nil { + return base + endpoint + } + end := endpoint + if end == "" { + end = "/" + } + ref, err := url.Parse(end) + if err != nil { + return base + endpoint + } + return u.ResolveReference(ref).String() +} diff --git a/constant/task.go b/constant/task.go index 1a68b812..d466fc8a 100644 --- a/constant/task.go +++ b/constant/task.go @@ -5,6 +5,7 @@ type TaskPlatform string const ( TaskPlatformSuno TaskPlatform = "suno" TaskPlatformMidjourney = "mj" + TaskPlatformKling TaskPlatform = "kling" ) const ( diff --git a/controller/channel-test.go b/controller/channel-test.go index d162d8cf..d54ccf0d 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if channel.Type == common.ChannelTypeSunoAPI { return errors.New("suno channel test is not supported"), nil } + if channel.Type == common.ChannelTypeKling { + return errors.New("kling channel test is not supported"), nil + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) @@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr info := relaycommon.GenRelayInfo(c) - err = helper.ModelMappedHelper(c, info) + err = helper.ModelMappedHelper(c, info, nil) if err != nil { return err, nil } diff --git a/controller/model.go b/controller/model.go index 134217a3..78bd32d6 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,6 +2,7 @@ package controller import ( "fmt" + "github.com/samber/lo" "net/http" "one-api/common" "one-api/constant" @@ -136,6 +137,9 @@ func init() { adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } + openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string { + return m.Id + }) } func ListModels(c *gin.Context) { diff --git a/controller/ratio_config.go b/controller/ratio_config.go new file mode 100644 index 00000000..6ddc3d9e --- /dev/null +++ b/controller/ratio_config.go @@ -0,0 +1,24 @@ +package controller + +import ( + "net/http" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +func GetRatioConfig(c *gin.Context) { + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} \ No newline at end of file diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go new file mode 100644 index 00000000..f749f384 --- /dev/null +++ b/controller/ratio_sync.go @@ -0,0 +1,322 @@ +package controller + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "sync" + "time" + + "one-api/common" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" + + "github.com/gin-gonic/gin" +) + +const ( + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 +) + +var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + +type upstreamResult struct { + Name string `json:"name"` + Data map[string]any `json:"data,omitempty"` + Err string `json:"err,omitempty"` +} + +func FetchUpstreamRatios(c *gin.Context) { + var req dto.UpstreamRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) + return + } + + if req.Timeout <= 0 { + req.Timeout = defaultTimeoutSeconds + } + + var upstreams []dto.UpstreamDTO + + if len(req.ChannelIDs) > 0 { + intIds := make([]int, 0, len(req.ChannelIDs)) + for _, id64 := range req.ChannelIDs { + intIds = append(intIds, int(id64)) + } + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return + } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } + + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return + } + + var wg sync.WaitGroup + ch := make(chan upstreamResult, len(upstreams)) + + sem := make(chan struct{}, maxConcurrentFetches) + + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + + for _, chn := range upstreams { + wg.Add(1) + go func(chItem dto.UpstreamDTO) { + defer wg.Done() + + sem <- struct{}{} + defer func() { <-sem }() + + endpoint := chItem.Endpoint + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL := chItem.BaseURL + endpoint + + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + + resp, err := client.Do(httpReq) + if err != nil { + common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) + ch <- upstreamResult{Name: chItem.Name, Err: resp.Status} + return + } + var body struct { + Success bool `json:"success"` + Data map[string]any `json:"data"` + Message string `json:"message"` + } + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + if !body.Success { + ch <- upstreamResult{Name: chItem.Name, Err: body.Message} + return + } + ch <- upstreamResult{Name: chItem.Name, Data: body.Data} + }(chn) + } + + wg.Wait() + close(ch) + + localData := ratio_setting.GetExposedData() + + var testResults []dto.TestResult + var successfulChannels []struct { + name string + data map[string]any + } + + for r := range ch { + if r.Err != "" { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "error", + Error: r.Err, + }) + } else { + testResults = append(testResults, dto.TestResult{ + Name: r.Name, + Status: "success", + }) + successfulChannels = append(successfulChannels, struct { + name string + data map[string]any + }{name: r.Name, data: r.Data}) + } + } + + differences := buildDifferences(localData, successfulChannels) + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "differences": differences, + "test_results": testResults, + }, + }) +} + +func buildDifferences(localData map[string]any, successfulChannels []struct { + name string + data map[string]any +}) map[string]map[string]dto.DifferenceItem { + differences := make(map[string]map[string]dto.DifferenceItem) + + allModels := make(map[string]struct{}) + + for _, ratioType := range ratioTypes { + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + for modelName := range localRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + for _, channel := range successfulChannels { + for _, ratioType := range ratioTypes { + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + for modelName := range upstreamRatio { + allModels[modelName] = struct{}{} + } + } + } + } + + for modelName := range allModels { + for _, ratioType := range ratioTypes { + var localValue interface{} = nil + if localRatioAny, ok := localData[ratioType]; ok { + if localRatio, ok := localRatioAny.(map[string]float64); ok { + if val, exists := localRatio[modelName]; exists { + localValue = val + } + } + } + + upstreamValues := make(map[string]interface{}) + hasUpstreamValue := false + hasDifference := false + + for _, channel := range successfulChannels { + var upstreamValue interface{} = nil + + if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { + if val, exists := upstreamRatio[modelName]; exists { + upstreamValue = val + hasUpstreamValue = true + + if localValue != nil && localValue != val { + hasDifference = true + } else if localValue == val { + upstreamValue = "same" + } + } + } + if upstreamValue == nil && localValue == nil { + upstreamValue = "same" + } + + if localValue == nil && upstreamValue != nil && upstreamValue != "same" { + hasDifference = true + } + + upstreamValues[channel.name] = upstreamValue + } + + shouldInclude := false + + if localValue != nil { + if hasDifference { + shouldInclude = true + } + } else { + if hasUpstreamValue { + shouldInclude = true + } + } + + if shouldInclude { + if differences[modelName] == nil { + differences[modelName] = make(map[string]dto.DifferenceItem) + } + differences[modelName][ratioType] = dto.DifferenceItem{ + Current: localValue, + Upstreams: upstreamValues, + } + } + } + } + + channelHasDiff := make(map[string]bool) + for _, ratioMap := range differences { + for _, item := range ratioMap { + for chName, val := range item.Upstreams { + if val != nil && val != "same" { + channelHasDiff[chName] = true + } + } + } + } + + for modelName, ratioMap := range differences { + for ratioType, item := range ratioMap { + for chName := range item.Upstreams { + if !channelHasDiff[chName] { + delete(item.Upstreams, chName) + } + } + differences[modelName][ratioType] = item + } + } + + return differences +} + +func GetSyncableChannels(c *gin.Context) { + channels, err := model.GetAllChannels(0, 0, true, false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + var syncableChannels []dto.SyncableChannel + for _, channel := range channels { + if channel.GetBaseURL() != "" { + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: channel.Id, + Name: channel.Name, + BaseURL: channel.GetBaseURL(), + Status: channel.Status, + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": syncableChannels, + }) +} \ No newline at end of file diff --git a/controller/relay.go b/controller/relay.go index c1c45114..4da4262b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) { func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { var err *dto.TaskError switch relayMode { - case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID: err = relay.RelayTaskFetch(c, relayMode) default: err = relay.RelayTaskSubmit(c, relayMode) diff --git a/controller/task.go b/controller/task.go index 34e14f3f..f7523e87 100644 --- a/controller/task.go +++ b/controller/task.go @@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) case constant.TaskPlatformSuno: _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) + case constant.TaskPlatformKling: + _ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM) default: common.SysLog("未知平台") } diff --git a/controller/task_video.go b/controller/task_video.go new file mode 100644 index 00000000..a2c2431d --- /dev/null +++ b/controller/task_video.go @@ -0,0 +1,140 @@ +package controller + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/model" + "one-api/relay" + "one-api/relay/channel" +) + +func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) + } + } + return nil +} + +func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + cacheGetChannel, err := model.CacheGetChannel(channelId) + if err != nil { + errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if errUpdate != nil { + common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + } + return fmt.Errorf("CacheGetChannel failed: %w", err) + } + adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling) + if adaptor == nil { + return fmt.Errorf("video adaptor not found") + } + for _, taskId := range taskIds { + if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) + } + } + return nil +} + +func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error { + baseURL := common.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ + "task_id": taskId, + }) + if err != nil { + return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err) + } + + var responseItem map[string]interface{} + err = json.Unmarshal(responseBody, &responseItem) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody))) + return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err) + } + + code, _ := responseItem["code"].(float64) + if code != 0 { + return fmt.Errorf("video task fetch failed for task %s", taskId) + } + + data, ok := responseItem["data"].(map[string]interface{}) + if !ok { + common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody))) + return fmt.Errorf("video task data format error for task %s", taskId) + } + + task := taskM[taskId] + if task == nil { + common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) + return fmt.Errorf("task %s not found", taskId) + } + + if status, ok := data["task_status"].(string); ok { + switch status { + case "submitted", "queued": + task.Status = model.TaskStatusSubmitted + case "processing": + task.Status = model.TaskStatusInProgress + case "succeed": + task.Status = model.TaskStatusSuccess + task.Progress = "100%" + if url, err := adaptor.ParseResultUrl(responseItem); err == nil { + task.FailReason = url + } else { + common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error())) + } + case "failed": + task.Status = model.TaskStatusFailure + task.Progress = "100%" + if reason, ok := data["fail_reason"].(string); ok { + task.FailReason = reason + } + } + } + + // If task failed, refund quota + if task.Status == model.TaskStatusFailure { + common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) + quota := task.Quota + if quota != 0 { + if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { + common.LogError(ctx, "Failed to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + + task.Data = responseBody + if err := task.Update(); err != nil { + common.SysError("UpdateVideoTask task error: " + err.Error()) + } + + return nil +} diff --git a/controller/topup.go b/controller/topup.go index 951b2cf2..827dda39 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) { c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) return } - payType := "wxpay" - if req.PaymentMethod == "zfb" { - payType = "alipay" - } - if req.PaymentMethod == "wx" { - req.PaymentMethod = "wxpay" - payType = "wxpay" + + if !setting.ContainsPayMethod(req.PaymentMethod) { + c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"}) + return } + callBackAddress := service.GetCallbackAddress() returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") @@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) { return } uri, params, err := client.Purchase(&epay.PurchaseArgs{ - Type: payType, + Type: req.PaymentMethod, ServiceTradeNo: tradeNo, Name: fmt.Sprintf("TUC%d", req.Amount), Money: strconv.FormatFloat(payMoney, 'f', 2, 64), diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go new file mode 100644 index 00000000..55a89025 --- /dev/null +++ b/dto/ratio_sync.go @@ -0,0 +1,49 @@ +package dto + +// UpstreamDTO 提交到后端同步倍率的上游渠道信息 +// Endpoint 可以为空,后端会默认使用 /api/ratio_config +// BaseURL 必须以 http/https 开头,不要以 / 结尾 +// 例如: https://api.example.com +// Endpoint: /api/ratio_config +// 提交示例: +// { +// "name": "openai", +// "base_url": "https://api.openai.com", +// "endpoint": "/ratio_config" +// } + +type UpstreamDTO struct { + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` +} + +type UpstreamRequest struct { + ChannelIDs []int64 `json:"channel_ids"` + Timeout int `json:"timeout"` +} + +// TestResult 上游测试连通性结果 +type TestResult struct { + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +// DifferenceItem 差异项 +// Current 为本地值,可能为 nil +// Upstreams 为各渠道的上游值,具体数值 / "same" / nil + +type DifferenceItem struct { + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` +} + +// SyncableChannel 可同步的渠道信息(base_url 不为空) + +type SyncableChannel struct { + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} \ No newline at end of file diff --git a/dto/video.go b/dto/video.go new file mode 100644 index 00000000..5b48146a --- /dev/null +++ b/dto/video.go @@ -0,0 +1,47 @@ +package dto + +type VideoRequest struct { + Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID + Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt + Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64) + Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds) + Width int `json:"width" example:"512"` // Video width + Height int `json:"height" example:"512"` // Video height + Fps int `json:"fps,omitempty" example:"30"` // Video frame rate + Seed int `json:"seed,omitempty" example:"20231234"` // Random seed + N int `json:"n,omitempty" example:"1"` // Number of videos to generate + ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format + User string `json:"user,omitempty" example:"user-1234"` // User identifier + Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.) +} + +// VideoResponse 视频生成提交任务后的响应 +type VideoResponse struct { + TaskId string `json:"task_id"` + Status string `json:"status"` +} + +// VideoTaskResponse 查询视频生成任务状态的响应 +type VideoTaskResponse struct { + TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID + Status string `json:"status" example:"succeeded"` // 任务状态 + Url string `json:"url,omitempty"` // 视频资源URL(成功时) + Format string `json:"format,omitempty" example:"mp4"` // 视频格式 + Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据 + Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时) +} + +// VideoTaskMetadata 视频任务元数据 +type VideoTaskMetadata struct { + Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长 + Fps int `json:"fps" example:"30"` // 实际帧率 + Width int `json:"width" example:"512"` // 实际宽度 + Height int `json:"height" example:"512"` // 实际高度 + Seed int `json:"seed" example:"20231234"` // 使用的随机种子 +} + +// VideoTaskError 视频任务错误信息 +type VideoTaskError struct { + Code int `json:"code"` + Message string `json:"message"` +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 84eb182e..9d074ce8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") { + relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeKlingFetchByID { + shouldSelectChannel = false + } else { + err = common.UnmarshalBodyReusable(c, &modelRequest) + } + c.Set("platform", string(constant.TaskPlatformKling)) + c.Set("relay_mode", relayMode) } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent relayMode := relayconstant.RelayModeGemini diff --git a/model/option.go b/model/option.go index ec386b29..ea72e5ee 100644 --- a/model/option.go +++ b/model/option.go @@ -127,6 +127,7 @@ func InitOptionMap() { common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() + common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled()) // 自动添加所有注册的模型配置 modelConfigs := config.GlobalConfig.ExportAllConfigs() @@ -267,6 +268,8 @@ func updateOptionMap(key string, value string) (err error) { setting.WorkerAllowHttpImageRequestEnabled = boolValue case "DefaultUseAutoGroup": setting.DefaultUseAutoGroup = boolValue + case "ExposeRatioEnabled": + ratio_setting.SetExposeRatioEnabled(boolValue) } } switch key { diff --git a/relay/relay-audio.go b/relay/audio_handler.go similarity index 91% rename from relay/relay-audio.go rename to relay/audio_handler.go index deb45c58..c1ce1a02 100644 --- a/relay/relay-audio.go +++ b/relay/audio_handler.go @@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. } func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) if err != nil { @@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) - } + promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model) preConsumedTokens = promptTokens relayInfo.PromptTokens = promptTokens } @@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } }() - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, audioRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - audioRequest.Model = relayInfo.UpstreamModelName - adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 50255d0a..873997f6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -44,4 +44,6 @@ type TaskAdaptor interface { // FetchTask FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) + + ParseResultUrl(resp map[string]any) (string, error) } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index ba20adea..406ebc8a 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { if requestMode == RequestModeCompletion { - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) + claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) } else { if claudeInfo.Usage.PromptTokens == 0 { //上游出错 @@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau if common.DebugEnabled { common.SysError("claude response usage is not complete, maybe upstream error") } - claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) + claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } @@ -618,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } if requestMode == RequestModeCompletion { - completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) - if err != nil { - return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError) - } + completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) claudeInfo.Usage.PromptTokens = info.PromptTokens claudeInfo.Usage.CompletionTokens = completionTokens claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index a487429c..50d4928a 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela if err := scanner.Err(); err != nil { common.LogError(c, "error_scanning_stream_response: "+err.Error()) } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) err := helper.ObjectData(c, response) @@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) for _, choice := range response.Choices { responseText += choice.Message.StringContent() } - usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) response.Usage = *usage response.Id = helper.GetResponseID(c) jsonResponse, err := json.Marshal(response) @@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) + usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 8a044bf2..29064242 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } }) if usage.PromptTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } return nil, usage } diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 6db40213..ac76476f 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo var currentEvent string var currentData string - var usage dto.Usage + var usage = &dto.Usage{} for scanner.Scan() { line := scanner.Text() @@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo if line == "" { if currentEvent != "" && currentData != "" { // handle last event - handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info) currentEvent = "" currentData = "" } @@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo // Last event if currentEvent != "" && currentData != "" { - handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) + handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info) } if err := scanner.Err(); err != nil { @@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo helper.Done(c) if usage.TotalTokens == 0 { - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count")) } - return nil, &usage + return nil, usage } func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 93e3e8d6..115aed1b 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re return true }) helper.Done(c) - err := resp.Body.Close() - if err != nil { - // return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - common.SysError("close_response_body_failed: " + err.Error()) - } if usage.TotalTokens == 0 { - usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } usage.CompletionTokens += nodeToken return nil, usage diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index a81eb3a9..968d9c9b 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { // 新增逻辑:处理 -thinking- 格式 - if strings.Contains(info.OriginModelName, "-thinking-") { + if strings.Contains(info.UpstreamModelName, "-thinking-") { parts := strings.Split(info.UpstreamModelName, "-thinking-") info.UpstreamModelName = parts[0] - } else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配 + } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配 info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") } } diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index bf87eafd..39757cef 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -9,6 +9,7 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "strings" "github.com/gin-gonic/gin" ) @@ -75,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info helper.SetEventStreamHeaders(c) + responseText := strings.Builder{} + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var geminiResponse GeminiChatResponse err := common.DecodeJsonStr(data, &geminiResponse) @@ -89,6 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info if part.InlineData != nil && part.InlineData.MimeType != "" { imageCount++ } + if part.Text != "" { + responseText.WriteString(part.Text) + } } } @@ -108,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info } // 直接发送 GeminiChatResponse 响应 - err = helper.ObjectData(c, geminiResponse) + err = helper.StringData(c, data) if err != nil { common.LogError(c, err.Error()) } @@ -122,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info } } - // 计算最终使用量 - // usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens + // 如果usage.CompletionTokens为0,则使用本地统计的completion tokens + if usage.CompletionTokens == 0 { + str := responseText.String() + if len(str) > 0 { + usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens) + } else { + // 空补全,不需要使用量 + usage = &dto.Usage{} + } + } // 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为 //helper.Done(c) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index d4b7c209..ef2c35be 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon } if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - modelName := info.OriginModelName + modelName := info.UpstreamModelName isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4dc0fc60..71590cd6 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -8,7 +8,6 @@ import ( "math" "mime/multipart" "net/http" - "path/filepath" "one-api/common" "one-api/constant" "one-api/dto" @@ -16,6 +15,7 @@ import ( "one-api/relay/helper" "one-api/service" "os" + "path/filepath" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel } if !containStreamUsage { - usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } else { if info.ChannelType == common.ChannelTypeDeepSeek { @@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI StatusCode: resp.StatusCode, }, nil } - + forceFormat := false if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { forceFormat = forceFmt @@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) + ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ @@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { // the status code has been judged before, if there is a body reading failure, // it should be regarded as a non-recoverable error, so it should not return err for external retry. - // Analogous to nginx's load balancing, it will only retry if it can't be requested or - // if the upstream returns a specific status code, once the upstream has already written the header, - // the subsequent failure of the response body should be regarded as a non-recoverable error, + // Analogous to nginx's load balancing, it will only retry if it can't be requested or + // if the upstream returns a specific status code, once the upstream has already written the header, + // the subsequent failure of the response body should be regarded as a non-recoverable error, // and can be terminated directly. defer resp.Body.Close() usage := &dto.Usage{} @@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) { if err = c.ShouldBind(&reqBody); err != nil { return 0, errors.WithStack(err) } - ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 + ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 reqFp, err := reqBody.File.Open() if err != nil { return 0, errors.WithStack(err) } - defer reqFp.Close() + defer reqFp.Close() tmpFp, err := os.CreateTemp("", "audio-*"+ext) if err != nil { diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index 1d1e060e..da9382c3 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc tempStr := responseTextBuilder.String() if len(tempStr) > 0 { // 非正常结束,使用输出文本的 token 数量 - completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) + completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName) usage.CompletionTokens = completionTokens } } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 3a06e7ee..aee4a307 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 0c6f8641..9d3dbd67 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model) + completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go new file mode 100644 index 00000000..9ea58728 --- /dev/null +++ b/relay/channel/task/kling/adaptor.go @@ -0,0 +1,312 @@ +package kling + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" + + "one-api/common" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" +) + +// ============================ +// Request / Response structures +// ============================ + +type SubmitReq struct { + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +type requestPayload struct { + Prompt string `json:"prompt,omitempty"` + Image string `json:"image,omitempty"` + Mode string `json:"mode,omitempty"` + Duration string `json:"duration,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + Model string `json:"model,omitempty"` + ModelName string `json:"model_name,omitempty"` + CfgScale float64 `json:"cfg_scale,omitempty"` +} + +type responsePayload struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + TaskID string `json:"task_id"` + } `json:"data"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + accessKey string + secretKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.BaseUrl + + // apiKey format: "access_key,secret_key" + keyParts := strings.Split(info.ApiKey, ",") + if len(keyParts) == 2 { + a.accessKey = strings.TrimSpace(keyParts[0]) + a.secretKey = strings.TrimSpace(keyParts[1]) + } +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + action := "generate" + info.Action = action + + var req SubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + if strings.TrimSpace(req.Prompt) == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) + return + } + + // Store into context for later usage + c.Set("kling_request", req) + return nil +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + token, err := a.createJWTToken() + if err != nil { + return fmt.Errorf("failed to create JWT token: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + return nil +} + +// BuildRequestBody converts request into Kling specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + v, exists := c.Get("kling_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(SubmitReq) + + body := a.convertToRequestPayload(&req) + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + + // Attempt Kling response parse first. + var kResp responsePayload + if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 { + c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID}) + return kResp.Data.TaskID, responseBody, nil + } + + // Fallback generic task response. + var generic dto.TaskResponse[string] + if err := json.Unmarshal(responseBody, &generic); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if !generic.IsSuccess() { + taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": generic.Data}) + return generic.Data, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + token, err := a.createJWTTokenWithKey(key) + if err != nil { + token = key + } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + req = req.WithContext(ctx) + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("User-Agent", "kling-sdk/1.0") + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return []string{"kling-v1", "kling-v1-6", "kling-v2-master"} +} + +func (a *TaskAdaptor) GetChannelName() string { + return "kling" +} + +// ============================ +// helpers +// ============================ + +func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload { + r := &requestPayload{ + Prompt: req.Prompt, + Image: req.Image, + Mode: defaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + AspectRatio: a.getAspectRatio(req.Size), + Model: req.Model, + ModelName: req.Model, + CfgScale: 0.5, + } + if r.Model == "" { + r.Model = "kling-v1" + r.ModelName = "kling-v1" + } + return r +} + +func (a *TaskAdaptor) getAspectRatio(size string) string { + switch size { + case "1024x1024", "512x512": + return "1:1" + case "1280x720", "1920x1080": + return "16:9" + case "720x1280", "1080x1920": + return "9:16" + default: + return "1:1" + } +} + +func defaultString(s, def string) string { + if strings.TrimSpace(s) == "" { + return def + } + return s +} + +func defaultInt(v int, def int) int { + if v == 0 { + return def + } + return v +} + +// ============================ +// JWT helpers +// ============================ + +func (a *TaskAdaptor) createJWTToken() (string, error) { + return a.createJWTTokenWithKeys(a.accessKey, a.secretKey) +} + +func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) { + parts := strings.Split(apiKey, ",") + if len(parts) != 2 { + return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'") + } + return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) +} + +func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) { + if accessKey == "" || secretKey == "" { + return "", fmt.Errorf("access key and secret key are required") + } + now := time.Now().Unix() + claims := jwt.MapClaims{ + "iss": accessKey, + "exp": now + 1800, // 30 minutes + "nbf": now - 5, + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["typ"] = "JWT" + return token.SignedString([]byte(secretKey)) +} + +// ParseResultUrl 提取视频任务结果的 url +func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { + data, ok := resp["data"].(map[string]any) + if !ok { + return "", fmt.Errorf("data field not found or invalid") + } + taskResult, ok := data["task_result"].(map[string]any) + if !ok { + return "", fmt.Errorf("task_result field not found or invalid") + } + videos, ok := taskResult["videos"].([]interface{}) + if !ok || len(videos) == 0 { + return "", fmt.Errorf("videos field not found or empty") + } + video, ok := videos[0].(map[string]interface{}) + if !ok { + return "", fmt.Errorf("video item invalid") + } + url, ok := video["url"].(string) + if !ok || url == "" { + return "", fmt.Errorf("url field not found or invalid") + } + return url, nil +} diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 03d60516..f7042348 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -22,6 +22,10 @@ type TaskAdaptor struct { ChannelType int } +func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) { + return "", nil // todo implement this method if needed +} + func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 44718a25..7ea3aae7 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = tencentHandler(c, resp) } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 424234a8..e568f651 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { suffix := "" if a.RequestMode == RequestModeGemini { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { - // suffix -thinking and -nothinking - if strings.HasSuffix(info.OriginModelName, "-thinking") { + // 新增逻辑:处理 -thinking- 格式 + if strings.Contains(info.UpstreamModelName, "-thinking-") { + parts := strings.Split(info.UpstreamModelName, "-thinking-") + info.UpstreamModelName = parts[0] + } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配 info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") - } else if strings.HasSuffix(info.OriginModelName, "-nothinking") { + } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") } } diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index e019c2dc..408160fb 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel }) if !containStreamUsage { - usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index e8805255..42139ddf 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { relayInfo.IsStream = true } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - textRequest.Model = relayInfo.UpstreamModelName - promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) // count messages token error 计算promptTokens错误 if err != nil { @@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) { var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { - return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } if resp != nil { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index a842a58d..3759c363 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -34,9 +34,14 @@ type ClaudeConvertInfo struct { } const ( - RelayFormatOpenAI = "openai" - RelayFormatClaude = "claude" - RelayFormatGemini = "gemini" + RelayFormatOpenAI = "openai" + RelayFormatClaude = "claude" + RelayFormatGemini = "gemini" + RelayFormatOpenAIResponses = "openai_responses" + RelayFormatOpenAIAudio = "openai_audio" + RelayFormatOpenAIImage = "openai_image" + RelayFormatRerank = "rerank" + RelayFormatEmbedding = "embedding" ) type RerankerInfo struct { @@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo { func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { info := GenRelayInfo(c) info.RelayMode = relayconstant.RelayModeRerank + info.RelayFormat = RelayFormatRerank info.RerankerInfo = &RerankerInfo{ Documents: req.Documents, ReturnDocuments: req.GetReturnDocuments(), @@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { return info } +func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatOpenAIAudio + return info +} + +func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatEmbedding + return info +} + func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { info := GenRelayInfo(c) info.RelayMode = relayconstant.RelayModeResponses + info.RelayFormat = RelayFormatOpenAIResponses + + info.SupportStreamOptions = false + info.ResponsesUsageInfo = &ResponsesUsageInfo{ BuiltInTools: make(map[string]*BuildInToolInfo), } @@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel return info } +func GenRelayInfoGemini(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatGemini + info.ShouldIncludeUsage = false + return info +} + +func GenRelayInfoImage(c *gin.Context) *RelayInfo { + info := GenRelayInfo(c) + info.RelayFormat = RelayFormatOpenAIImage + return info +} + func GenRelayInfo(c *gin.Context) *RelayInfo { channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") @@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if streamSupportedChannels[info.ChannelType] { info.SupportStreamOptions = true } - // responses 模式不支持 StreamOptions - if relayconstant.RelayModeResponses == info.RelayMode { - info.SupportStreamOptions = false - } return info } diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index f22a20bd..02a286e2 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -38,6 +38,9 @@ const ( RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeKlingFetchByID + RelayModeKlingSubmit + RelayModeRerank RelayModeResponses @@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int { } return relayMode } + +func Path2RelayKling(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") { + relayMode = RelayModeKlingSubmit + } else if method == http.MethodGet && strings.Contains(path, "/video/generations/") { + relayMode = RelayModeKlingFetchByID + } + return relayMode +} diff --git a/relay/relay_embedding.go b/relay/embedding_handler.go similarity index 94% rename from relay/relay_embedding.go rename to relay/embedding_handler.go index b4909849..849c70da 100644 --- a/relay/relay_embedding.go +++ b/relay/embedding_handler.go @@ -15,7 +15,7 @@ import ( ) func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { - token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) + token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) return token } @@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed } func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoEmbedding(c) var embeddingRequest *dto.EmbeddingRequest err := common.UnmarshalBodyReusable(c, &embeddingRequest) @@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - embeddingRequest.Model = relayInfo.UpstreamModelName - promptToken := getEmbeddingPromptToken(*embeddingRequest) relayInfo.PromptTokens = promptToken diff --git a/relay/relay-gemini.go b/relay/gemini_handler.go similarity index 92% rename from relay/relay-gemini.go rename to relay/gemini_handler.go index 80e5a694..14d58cc5 100644 --- a/relay/relay-gemini.go +++ b/relay/gemini_handler.go @@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, return sensitiveWords, err } -func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { +func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int { // 计算输入 token 数量 var inputTexts []string for _, content := range req.Contents { @@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay } inputText := strings.Join(inputTexts, "\n") - inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) + inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName) info.PromptTokens = inputTokens - return inputTokens, err + return inputTokens } func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) } - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoGemini(c) // 检查 Gemini 流式模式 checkGeminiStreamMode(c, relayInfo) @@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } // model mapped 模型映射 - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) } @@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) } else { - promptTokens, err := getGeminiInputTokens(req, relayInfo) + promptTokens := getGeminiInputTokens(req, relayInfo) if err != nil { return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) } @@ -162,7 +162,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) if err != nil { common.LogError(c, "Do gemini request failed: "+err.Error()) - return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") diff --git a/relay/helper/model_mapped.go b/relay/helper/model_mapped.go index 9bf67c03..c1735149 100644 --- a/relay/helper/model_mapped.go +++ b/relay/helper/model_mapped.go @@ -4,12 +4,14 @@ import ( "encoding/json" "errors" "fmt" + common2 "one-api/common" + "one-api/dto" "one-api/relay/common" "github.com/gin-gonic/gin" ) -func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { +func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error { // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" && modelMapping != "{}" { @@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { info.UpstreamModelName = currentModel } } + if request != nil { + switch info.RelayFormat { + case common.RelayFormatGemini: + // Gemini 模型映射 + case common.RelayFormatClaude: + if claudeRequest, ok := request.(*dto.ClaudeRequest); ok { + claudeRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIResponses: + if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok { + openAIResponsesRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIAudio: + if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok { + openAIAudioRequest.Model = info.UpstreamModelName + } + case common.RelayFormatOpenAIImage: + if imageRequest, ok := request.(*dto.ImageRequest); ok { + imageRequest.Model = info.UpstreamModelName + } + case common.RelayFormatRerank: + if rerankRequest, ok := request.(*dto.RerankRequest); ok { + rerankRequest.Model = info.UpstreamModelName + } + case common.RelayFormatEmbedding: + if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok { + embeddingRequest.Model = info.UpstreamModelName + } + default: + if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok { + openAIRequest.Model = info.UpstreamModelName + } else { + common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request)) + } + } + } return nil } diff --git a/relay/relay-image.go b/relay/image_handler.go similarity index 98% rename from relay/relay-image.go rename to relay/image_handler.go index 197a8af6..57917025 100644 --- a/relay/relay-image.go +++ b/relay/image_handler.go @@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. } func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { - relayInfo := relaycommon.GenRelayInfo(c) + relayInfo := relaycommon.GenRelayInfoImage(c) imageRequest, err := getAndValidImageRequest(c, relayInfo) if err != nil { @@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, imageRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - imageRequest.Model = relayInfo.UpstreamModelName - priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) diff --git a/relay/relay-text.go b/relay/relay-text.go index 24fb8155..db8d0d3b 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { } } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - textRequest.Model = relayInfo.UpstreamModelName - // 获取 promptTokens,如果上下文中已经存在,则直接使用 var promptTokens int if value, exists := c.Get("prompt_tokens"); exists { @@ -253,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re case relayconstant.RelayModeChatCompletions: promptTokens, err = service.CountTokenChatRequest(info, *textRequest) case relayconstant.RelayModeCompletions: - promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model) case relayconstant.RelayModeModerations: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) case relayconstant.RelayModeEmbeddings: - promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) + promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model) default: err = errors.New("unknown relay mode") promptTokens = 0 diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 7bf0da9f..626bb7e4 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -22,6 +22,7 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" + "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/vertex" @@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { // return &aiproxy.Adaptor{} case commonconstant.TaskPlatformSuno: return &suno.TaskAdaptor{} + case commonconstant.TaskPlatformKling: + return &kling.TaskAdaptor{} } return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index 3da9a20f..245fd681 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) + if platform == constant.TaskPlatformKling { + modelName = relayInfo.OriginModelName + } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] @@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } relayInfo.ConsumeQuota = true // insert task - task := model.InitTask(constant.TaskPlatformSuno, relayInfo) + task := model.InitTask(platform, relayInfo) task.TaskID = taskID task.Quota = quota task.Data = taskData + task.Action = relayInfo.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) @@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ - relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, - relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, + relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, + relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder, } func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { @@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt return } +func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + taskId := c.Param("id") + userId := c.GetInt("id") + + originTask, exist, err := model.GetByTaskId(userId, taskId) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + return +} + func TaskModel2Dto(task *model.Task) *dto.TaskDto { return &dto.TaskDto{ TaskID: task.TaskID, diff --git a/relay/relay_rerank.go b/relay/rerank_handler.go similarity index 92% rename from relay/relay_rerank.go rename to relay/rerank_handler.go index 6ca98de7..319811b8 100644 --- a/relay/relay_rerank.go +++ b/relay/rerank_handler.go @@ -14,12 +14,10 @@ import ( ) func getRerankPromptToken(rerankRequest dto.RerankRequest) int { - token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) + token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) for _, document := range rerankRequest.Documents { - tkm, err := service.CountTokenInput(document, rerankRequest.Model) - if err == nil { - token += tkm - } + tkm := service.CountTokenInput(document, rerankRequest.Model) + token += tkm } return token } @@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, rerankRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) } - rerankRequest.Model = relayInfo.UpstreamModelName - promptToken := getRerankPromptToken(*rerankRequest) relayInfo.PromptTokens = promptToken diff --git a/relay/relay-responses.go b/relay/responses_handler.go similarity index 93% rename from relay/relay-responses.go rename to relay/responses_handler.go index fd3ddb5a..e744e354 100644 --- a/relay/relay-responses.go +++ b/relay/responses_handler.go @@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom return sensitiveWords, err } -func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) { - inputTokens, err := service.CountTokenInput(req.Input, req.Model) +func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int { + inputTokens := service.CountTokenInput(req.Input, req.Model) info.PromptTokens = inputTokens - return inputTokens, err + return inputTokens } func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { @@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) } } - err = helper.ModelMappedHelper(c, relayInfo) + err = helper.ModelMappedHelper(c, relayInfo, req) if err != nil { return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) } - req.Model = relayInfo.UpstreamModelName + if value, exists := c.Get("prompt_tokens"); exists { promptTokens := value.(int) relayInfo.SetPromptTokens(promptTokens) } else { - promptTokens, err := getInputTokens(req, relayInfo) - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) - } + promptTokens := getInputTokens(req, relayInfo) c.Set("prompt_tokens", promptTokens) } diff --git a/router/api-router.go b/router/api-router.go index 45930246..badfa7bf 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) + apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig) userRoute := apiRouter.Group("/user") { @@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) { optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 } + ratioSyncRoute := apiRouter.Group("/ratio_sync") + ratioSyncRoute.Use(middleware.RootAuth()) + { + ratioSyncRoute.GET("/channels", controller.GetSyncableChannels) + ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios) + } channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { diff --git a/router/main.go b/router/main.go index b8ac4055..0d2bfdce 100644 --- a/router/main.go +++ b/router/main.go @@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { SetApiRouter(router) SetDashboardRouter(router) SetRelayRouter(router) + SetVideoRouter(router) frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" diff --git a/router/video-router.go b/router/video-router.go new file mode 100644 index 00000000..7201c34a --- /dev/null +++ b/router/video-router.go @@ -0,0 +1,17 @@ +package router + +import ( + "one-api/controller" + "one-api/middleware" + + "github.com/gin-gonic/gin" +) + +func SetVideoRouter(router *gin.Engine) { + videoV1Router := router.Group("/v1") + videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + { + videoV1Router.POST("/video/generations", controller.RelayTask) + videoV1Router.GET("/video/generations/:task_id", controller.RelayTask) + } +} diff --git a/service/token_counter.go b/service/token_counter.go index 82de0a05..53c6c2fa 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA countStr += fmt.Sprintf("%v", tool.Function.Parameters) } } - toolTokens, err := CountTokenInput(countStr, request.Model) + toolTokens := CountTokenInput(countStr, request.Model) if err != nil { return 0, err } @@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro // Count tokens in system message if request.System != "" { - systemTokens, err := CountTokenInput(request.System, model) + systemTokens := CountTokenInput(request.System, model) if err != nil { return 0, err } @@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, switch request.Type { case dto.RealtimeEventTypeSessionUpdate: if request.Session != nil { - msgTokens, err := CountTextToken(request.Session.Instructions, model) - if err != nil { - return 0, 0, err - } + msgTokens := CountTextToken(request.Session.Instructions, model) textToken += msgTokens } case dto.RealtimeEventResponseAudioDelta: @@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, audioToken += atk case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta: // count text token - tkm, err := CountTextToken(request.Delta, model) - if err != nil { - return 0, 0, fmt.Errorf("error counting text token: %v", err) - } + tkm := CountTextToken(request.Delta, model) textToken += tkm case dto.RealtimeEventInputAudioBufferAppend: // count audio token @@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, case "message": for _, content := range request.Item.Content { if content.Type == "input_text" { - tokens, err := CountTextToken(content.Text, model) - if err != nil { - return 0, 0, err - } + tokens := CountTextToken(content.Text, model) textToken += tokens } } @@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, if !info.IsFirstRequest { if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 { for _, tool := range info.RealtimeTools { - toolTokens, err := CountTokenInput(tool, model) - if err != nil { - return 0, 0, err - } + toolTokens := CountTokenInput(tool, model) textToken += 8 textToken += toolTokens } @@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod return tokenNum, nil } -func CountTokenInput(input any, model string) (int, error) { +func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: return CountTextToken(v, model) @@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) { func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int { tokens := 0 for _, message := range messages { - tkm, _ := CountTokenInput(message.Delta.GetContentString(), model) + tkm := CountTokenInput(message.Delta.GetContentString(), model) tokens += tkm if message.Delta.ToolCalls != nil { for _, tool := range message.Delta.ToolCalls { - tkm, _ := CountTokenInput(tool.Function.Name, model) + tkm := CountTokenInput(tool.Function.Name, model) tokens += tkm - tkm, _ = CountTokenInput(tool.Function.Arguments, model) + tkm = CountTokenInput(tool.Function.Arguments, model) tokens += tkm } } @@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, return tokens } -func CountTTSToken(text string, model string) (int, error) { +func CountTTSToken(text string, model string) int { if strings.HasPrefix(model, "tts") { - return utf8.RuneCountInString(text), nil + return utf8.RuneCountInString(text) } else { return CountTextToken(text, model) } @@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) //} // CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTextToken(text string, model string) (int, error) { - var err error +func CountTextToken(text string, model string) int { + if text == "" { + return 0 + } tokenEncoder := getTokenEncoder(model) - return getTokenNum(tokenEncoder, text), err + return getTokenNum(tokenEncoder, text) } diff --git a/service/usage_helpr.go b/service/usage_helpr.go index c52e1e15..ca9c0830 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -16,13 +16,13 @@ import ( // return 0, errors.New("unknown relay mode") //} -func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { +func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTextToken(responseText, modeName) + ctkm := CountTextToken(responseText, modeName) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage, err + return usage } func ValidUsage(usage *dto.Usage) bool { diff --git a/setting/payment.go b/setting/payment.go index 4ffa4381..3fc0f14a 100644 --- a/setting/payment.go +++ b/setting/payment.go @@ -13,12 +13,12 @@ var PayMethods = []map[string]string{ { "name": "支付宝", "color": "rgba(var(--semi-blue-5), 1)", - "type": "zfb", + "type": "alipay", }, { "name": "微信", "color": "rgba(var(--semi-green-5), 1)", - "type": "wx", + "type": "wxpay", }, } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index aa934b22..51d473a8 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error { cacheRatioMapMutex.Lock() defer cacheRatioMapMutex.Unlock() cacheRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetCacheRatio returns the cache ratio for a model @@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) { } return ratio, true } + +func GetCacheRatioCopy() map[string]float64 { + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(cacheRatioMap)) + for k, v := range cacheRatioMap { + copyMap[k] = v + } + return copyMap +} diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go new file mode 100644 index 00000000..8fca0bcb --- /dev/null +++ b/setting/ratio_setting/expose_ratio.go @@ -0,0 +1,17 @@ +package ratio_setting + +import "sync/atomic" + +var exposeRatioEnabled atomic.Bool + +func init() { + exposeRatioEnabled.Store(false) +} + +func SetExposeRatioEnabled(enabled bool) { + exposeRatioEnabled.Store(enabled) +} + +func IsExposeRatioEnabled() bool { + return exposeRatioEnabled.Load() +} \ No newline at end of file diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go new file mode 100644 index 00000000..9e5b6c30 --- /dev/null +++ b/setting/ratio_setting/exposed_cache.go @@ -0,0 +1,55 @@ +package ratio_setting + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/gin-gonic/gin" +) + +const exposedDataTTL = 30 * time.Second + +type exposedCache struct { + data gin.H + expiresAt time.Time +} + +var ( + exposedData atomic.Value + rebuildMu sync.Mutex +) + +func InvalidateExposedDataCache() { + exposedData.Store((*exposedCache)(nil)) +} + +func cloneGinH(src gin.H) gin.H { + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func GetExposedData() gin.H { + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} \ No newline at end of file diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 3102dfe9..1eaf25b1 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -317,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error { modelPriceMapMutex.Lock() defer modelPriceMapMutex.Unlock() modelPriceMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelPriceMap) + err := json.Unmarshal([]byte(jsonStr), &modelPriceMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false @@ -345,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() modelRatioMap = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &modelRatioMap) + err := json.Unmarshal([]byte(jsonStr), &modelRatioMap) + if err == nil { + InvalidateExposedDataCache() + } + return err } // 处理带有思考预算的模型名称,方便统一定价 @@ -405,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { CompletionRatioMutex.Lock() defer CompletionRatioMutex.Unlock() CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) + err := json.Unmarshal([]byte(jsonStr), &CompletionRatio) + if err == nil { + InvalidateExposedDataCache() + } + return err } func GetCompletionRatio(name string) float64 { @@ -609,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) { } return ratio, true } + +func GetModelRatioCopy() map[string]float64 { + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelRatioMap)) + for k, v := range modelRatioMap { + copyMap[k] = v + } + return copyMap +} + +func GetModelPriceCopy() map[string]float64 { + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + copyMap := make(map[string]float64, len(modelPriceMap)) + for k, v := range modelPriceMap { + copyMap[k] = v + } + return copyMap +} + +func GetCompletionRatioCopy() map[string]float64 { + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() + copyMap := make(map[string]float64, len(CompletionRatio)) + for k, v := range CompletionRatio { + copyMap[k] = v + } + return copyMap +} diff --git a/web/src/components/settings/ChannelSelectorModal.js b/web/src/components/settings/ChannelSelectorModal.js new file mode 100644 index 00000000..573329b3 --- /dev/null +++ b/web/src/components/settings/ChannelSelectorModal.js @@ -0,0 +1,143 @@ +import React, { useState } from 'react'; +import { + Modal, + Transfer, + Input, + Space, + Checkbox, + Avatar, + Highlight, +} from '@douyinfe/semi-ui'; +import { IconClose } from '@douyinfe/semi-icons'; + +const CHANNEL_STATUS_CONFIG = { + 1: { color: 'green', text: '启用' }, + 2: { color: 'red', text: '禁用' }, + 3: { color: 'amber', text: '自禁' }, + default: { color: 'grey', text: '未知' } +}; + +const getChannelStatusConfig = (status) => { + return CHANNEL_STATUS_CONFIG[status] || CHANNEL_STATUS_CONFIG.default; +}; + +export default function ChannelSelectorModal({ + t, + visible, + onCancel, + onOk, + allChannels = [], + selectedChannelIds = [], + setSelectedChannelIds, + channelEndpoints, + updateChannelEndpoint, +}) { + const [searchText, setSearchText] = useState(''); + + const ChannelInfo = ({ item, showEndpoint = false, isSelected = false }) => { + const channelId = item.key || item.value; + const currentEndpoint = channelEndpoints[channelId]; + const baseUrl = item._originalData?.base_url || ''; + const status = item._originalData?.status || 0; + const statusConfig = getChannelStatusConfig(status); + + return ( + <> + + {statusConfig.text} + +
+
+ {isSelected ? ( + item.label + ) : ( + + )} +
+
+ + {isSelected ? ( + baseUrl + ) : ( + + )} + + {showEndpoint && ( + updateChannelEndpoint(channelId, value)} + placeholder="/api/ratio_config" + className="flex-1 text-xs" + style={{ fontSize: '12px' }} + /> + )} + {isSelected && !showEndpoint && ( + + {currentEndpoint} + + )} +
+
+ + ); + }; + + const renderSourceItem = (item) => { + return ( +
+ + + +
+ ); + }; + + const renderSelectedItem = (item) => { + return ( +
+ + +
+ ); + }; + + const channelFilter = (input, item) => { + const searchLower = input.toLowerCase(); + return item.label.toLowerCase().includes(searchLower) || + (item._originalData?.base_url || '').toLowerCase().includes(searchLower); + }; + + return ( + {t('选择同步渠道')}} + width={1000} + > + + + + + ); +} \ No newline at end of file diff --git a/web/src/components/settings/RatioSetting.js b/web/src/components/settings/RatioSetting.js index bf97282c..1d87c6de 100644 --- a/web/src/components/settings/RatioSetting.js +++ b/web/src/components/settings/RatioSetting.js @@ -6,6 +6,7 @@ import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js' import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js'; import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js'; import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js'; +import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js'; import { API, showError } from '../../helpers'; @@ -21,6 +22,7 @@ const RatioSetting = () => { GroupGroupRatio: '', AutoGroups: '', DefaultUseAutoGroup: false, + ExposeRatioEnabled: false, UserUsableGroups: '', }); @@ -48,7 +50,7 @@ const RatioSetting = () => { // 如果后端返回的不是合法 JSON,直接展示 } } - if (['DefaultUseAutoGroup'].includes(item.key)) { + if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) { newInputs[item.key] = item.value === 'true' ? true : false; } else { newInputs[item.key] = item.value; @@ -78,10 +80,6 @@ const RatioSetting = () => { return ( - {/* 分组倍率设置 */} - - - {/* 模型倍率设置以及可视化编辑器 */} @@ -100,8 +98,18 @@ const RatioSetting = () => { refresh={onRefresh} /> + + + + {/* 分组倍率设置 */} + + + ); }; diff --git a/web/src/components/table/TaskLogsTable.js b/web/src/components/table/TaskLogsTable.js index b3d0ab7b..37bdde57 100644 --- a/web/src/components/table/TaskLogsTable.js +++ b/web/src/components/table/TaskLogsTable.js @@ -11,7 +11,9 @@ import { XCircle, Loader, List, - Hash + Hash, + Video, + Sparkles } from 'lucide-react'; import { API, @@ -80,6 +82,7 @@ const COLUMN_KEYS = { TASK_STATUS: 'task_status', PROGRESS: 'progress', FAIL_REASON: 'fail_reason', + RESULT_URL: 'result_url', }; const renderTimestamp = (timestampInSeconds) => { @@ -150,6 +153,7 @@ const LogsTable = () => { [COLUMN_KEYS.TASK_STATUS]: true, [COLUMN_KEYS.PROGRESS]: true, [COLUMN_KEYS.FAIL_REASON]: true, + [COLUMN_KEYS.RESULT_URL]: true, }; }; @@ -203,6 +207,12 @@ const LogsTable = () => { {t('生成歌词')} ); + case 'generate': + return ( + }> + {t('生成视频')} + + ); default: return ( }> @@ -220,6 +230,12 @@ const LogsTable = () => { Suno ); + case 'kling': + return ( + }> + Kling + + ); default: return ( }> @@ -411,10 +427,21 @@ const LogsTable = () => { }, { key: COLUMN_KEYS.FAIL_REASON, - title: t('失败原因'), + title: t('详情'), dataIndex: 'fail_reason', fixed: 'right', render: (text, record, index) => { + // 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接 + const isVideoTask = record.action === 'generate'; + const isSuccess = record.status === 'SUCCESS'; + const isUrl = typeof text === 'string' && /^https?:\/\//.test(text); + if (isSuccess && isVideoTask && isUrl) { + return ( + + {t('点击预览视频')} + + ); + } if (!text) { return t('无'); } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 20fed5b7..c4220bd4 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'Coze', }, + { + value: 50, + color: 'green', + label: '可灵', + }, ]; diff --git a/web/src/constants/common.constant.js b/web/src/constants/common.constant.js index 1a37d5f6..9ce83432 100644 --- a/web/src/constants/common.constant.js +++ b/web/src/constants/common.constant.js @@ -1 +1,3 @@ export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! + +export const DEFAULT_ENDPOINT = '/api/ratio_config'; \ No newline at end of file diff --git a/web/src/i18n/locales/en.json b/web/src/i18n/locales/en.json index fc80f9c1..ab793364 100644 --- a/web/src/i18n/locales/en.json +++ b/web/src/i18n/locales/en.json @@ -1665,5 +1665,28 @@ "确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?", "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.", "选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)", - "请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)" + "请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)", + "上游倍率同步": "Upstream ratio synchronization", + "获取渠道失败:": "Failed to get channels: ", + "请至少选择一个渠道": "Please select at least one channel", + "获取倍率失败:": "Failed to get ratios: ", + "后端请求失败": "Backend request failed", + "部分渠道测试失败:": "Some channels failed to test: ", + "已与上游倍率完全一致,无需同步": "The upstream ratio is completely consistent, no synchronization is required", + "请求后端接口失败:": "Failed to request the backend interface: ", + "同步成功": "Synchronization successful", + "部分保存失败": "Some settings failed to save", + "保存失败": "Save failed", + "选择同步渠道": "Select synchronization channel", + "应用同步": "Apply synchronization", + "倍率类型": "Ratio type", + "当前值": "Current value", + "上游值": "Upstream value", + "差异": "Difference", + "搜索渠道名称或地址": "Search channel name or address", + "缓存倍率": "Cache ratio", + "暂无差异化倍率显示": "No differential ratio display", + "请先选择同步渠道": "Please select the synchronization channel first", + "与本地相同": "Same as local", + "未找到匹配的模型": "No matching model found" } \ No newline at end of file diff --git a/web/src/index.css b/web/src/index.css index c1254fcc..ff7294ad 100644 --- a/web/src/index.css +++ b/web/src/index.css @@ -432,4 +432,72 @@ code { .semi-table-tbody>.semi-table-row { border-bottom: 1px solid rgba(0, 0, 0, 0.1); } +} + +/* ==================== 同步倍率 - 渠道选择器 ==================== */ + +.components-transfer-source-item, +.components-transfer-selected-item { + display: flex; + align-items: center; + padding: 8px; +} + +.semi-transfer-left-list, +.semi-transfer-right-list { + -ms-overflow-style: none; + scrollbar-width: none; +} + +.semi-transfer-left-list::-webkit-scrollbar, +.semi-transfer-right-list::-webkit-scrollbar { + display: none; +} + +.components-transfer-source-item .semi-checkbox, +.components-transfer-selected-item .semi-checkbox { + display: flex; + align-items: center; + width: 100%; +} + +.components-transfer-source-item .semi-avatar, +.components-transfer-selected-item .semi-avatar { + margin-right: 12px; + flex-shrink: 0; +} + +.components-transfer-source-item .info, +.components-transfer-selected-item .info { + flex: 1; + overflow: hidden; + display: flex; + flex-direction: column; + justify-content: center; +} + +.components-transfer-source-item .name, +.components-transfer-selected-item .name { + font-weight: 500; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +.components-transfer-source-item .email, +.components-transfer-selected-item .email { + font-size: 12px; + color: var(--semi-color-text-2); + display: flex; + align-items: center; +} + +.components-transfer-selected-item .semi-icon-close { + margin-left: 8px; + cursor: pointer; + color: var(--semi-color-text-2); +} + +.components-transfer-selected-item .semi-icon-close:hover { + color: var(--semi-color-text-0); } \ No newline at end of file diff --git a/web/src/pages/Detail/index.js b/web/src/pages/Detail/index.js index 15c02abf..0fd18d16 100644 --- a/web/src/pages/Detail/index.js +++ b/web/src/pages/Detail/index.js @@ -1112,7 +1112,6 @@ const Detail = (props) => { @@ -1389,7 +1388,6 @@ const Detail = (props) => { ) : ( + + + + setInputs({ ...inputs, ExposeRatioEnabled: value }) + } + /> + + diff --git a/web/src/pages/Setting/Ratio/UpstreamRatioSync.js b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js new file mode 100644 index 00000000..f83e0cdc --- /dev/null +++ b/web/src/pages/Setting/Ratio/UpstreamRatioSync.js @@ -0,0 +1,503 @@ +import React, { useState, useCallback, useMemo } from 'react'; +import { + Button, + Table, + Tag, + Empty, + Checkbox, + Form, + Input, +} from '@douyinfe/semi-ui'; +import { IconSearch } from '@douyinfe/semi-icons'; +import { + RefreshCcw, + CheckSquare, +} from 'lucide-react'; +import { API, showError, showSuccess, showWarning, stringToColor } from '../../../helpers'; +import { DEFAULT_ENDPOINT } from '../../../constants'; +import { useTranslation } from 'react-i18next'; +import { + IllustrationNoResult, + IllustrationNoResultDark +} from '@douyinfe/semi-illustrations'; +import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal'; + +export default function UpstreamRatioSync(props) { + const { t } = useTranslation(); + const [modalVisible, setModalVisible] = useState(false); + const [loading, setLoading] = useState(false); + const [syncLoading, setSyncLoading] = useState(false); + + // 渠道选择相关 + const [allChannels, setAllChannels] = useState([]); + const [selectedChannelIds, setSelectedChannelIds] = useState([]); + + // 渠道端点配置 + const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint } + + // 差异数据和测试结果 + const [differences, setDifferences] = useState({}); + const [resolutions, setResolutions] = useState({}); + + // 是否已经执行过同步 + const [hasSynced, setHasSynced] = useState(false); + + // 分页相关状态 + const [currentPage, setCurrentPage] = useState(1); + const [pageSize, setPageSize] = useState(10); + + // 搜索相关状态 + const [searchKeyword, setSearchKeyword] = useState(''); + + const fetchAllChannels = async () => { + setLoading(true); + try { + const res = await API.get('/api/ratio_sync/channels'); + + if (res.data.success) { + const channels = res.data.data || []; + + const transferData = channels.map(channel => ({ + key: channel.id, + label: channel.name, + value: channel.id, + disabled: false, + _originalData: channel, + })); + + setAllChannels(transferData); + + const initialEndpoints = {}; + transferData.forEach(channel => { + initialEndpoints[channel.key] = DEFAULT_ENDPOINT; + }); + setChannelEndpoints(initialEndpoints); + } else { + showError(res.data.message); + } + } catch (error) { + showError(t('获取渠道失败:') + error.message); + } finally { + setLoading(false); + } + }; + + const confirmChannelSelection = () => { + const selected = allChannels + .filter(ch => selectedChannelIds.includes(ch.value)) + .map(ch => ch._originalData); + + if (selected.length === 0) { + showWarning(t('请至少选择一个渠道')); + return; + } + + setModalVisible(false); + fetchRatiosFromChannels(selected); + }; + + const fetchRatiosFromChannels = async (channelList) => { + setSyncLoading(true); + + const payload = { + channel_ids: channelList.map(ch => parseInt(ch.id)), + timeout: 10, + }; + + try { + const res = await API.post('/api/ratio_sync/fetch', payload); + + if (!res.data.success) { + showError(res.data.message || t('后端请求失败')); + setSyncLoading(false); + return; + } + + const { differences = {}, test_results = [] } = res.data.data; + + const errorResults = test_results.filter(r => r.status === 'error'); + if (errorResults.length > 0) { + showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', ')); + } + + setDifferences(differences); + setResolutions({}); + setHasSynced(true); + + if (Object.keys(differences).length === 0) { + showSuccess(t('已与上游倍率完全一致,无需同步')); + } + } catch (e) { + showError(t('请求后端接口失败:') + e.message); + } finally { + setSyncLoading(false); + } + }; + + const selectValue = (model, ratioType, value) => { + setResolutions(prev => ({ + ...prev, + [model]: { + ...prev[model], + [ratioType]: value, + }, + })); + }; + + const applySync = async () => { + const currentRatios = { + ModelRatio: JSON.parse(props.options.ModelRatio || '{}'), + CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'), + CacheRatio: JSON.parse(props.options.CacheRatio || '{}'), + ModelPrice: JSON.parse(props.options.ModelPrice || '{}'), + }; + + Object.entries(resolutions).forEach(([model, ratios]) => { + Object.entries(ratios).forEach(([ratioType, value]) => { + const optionKey = ratioType + .split('_') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(''); + currentRatios[optionKey][model] = parseFloat(value); + }); + }); + + setLoading(true); + try { + const updates = Object.entries(currentRatios).map(([key, value]) => + API.put('/api/option/', { + key, + value: JSON.stringify(value, null, 2), + }) + ); + + const results = await Promise.all(updates); + + if (results.every(res => res.data.success)) { + showSuccess(t('同步成功')); + props.refresh(); + + setDifferences(prevDifferences => { + const newDifferences = { ...prevDifferences }; + + Object.entries(resolutions).forEach(([model, ratios]) => { + Object.keys(ratios).forEach(ratioType => { + if (newDifferences[model] && newDifferences[model][ratioType]) { + delete newDifferences[model][ratioType]; + + if (Object.keys(newDifferences[model]).length === 0) { + delete newDifferences[model]; + } + } + }); + }); + + return newDifferences; + }); + + setResolutions({}); + } else { + showError(t('部分保存失败')); + } + } catch (error) { + showError(t('保存失败')); + } finally { + setLoading(false); + } + }; + + const getCurrentPageData = (dataSource) => { + const startIndex = (currentPage - 1) * pageSize; + const endIndex = startIndex + pageSize; + return dataSource.slice(startIndex, endIndex); + }; + + const renderHeader = () => ( +
+
+
+ + + {(() => { + const hasSelections = Object.keys(resolutions).length > 0; + + return ( + + ); + })()} + + } + placeholder={t('搜索模型名称')} + value={searchKeyword} + onChange={setSearchKeyword} + className="!rounded-full w-full md:w-64 mt-2" + showClear + /> +
+
+
+ ); + + const renderDifferenceTable = () => { + const dataSource = useMemo(() => { + const tmp = []; + + Object.entries(differences).forEach(([model, ratioTypes]) => { + Object.entries(ratioTypes).forEach(([ratioType, diff]) => { + tmp.push({ + key: `${model}_${ratioType}`, + model, + ratioType, + current: diff.current, + upstreams: diff.upstreams, + }); + }); + }); + + return tmp; + }, [differences]); + + const filteredDataSource = useMemo(() => { + if (!searchKeyword.trim()) { + return dataSource; + } + + const keyword = searchKeyword.toLowerCase().trim(); + return dataSource.filter(item => + item.model.toLowerCase().includes(keyword) + ); + }, [dataSource, searchKeyword]); + + const upstreamNames = useMemo(() => { + const set = new Set(); + filteredDataSource.forEach((row) => { + Object.keys(row.upstreams || {}).forEach((name) => set.add(name)); + }); + return Array.from(set); + }, [filteredDataSource]); + + if (filteredDataSource.length === 0) { + return ( + } + darkModeImage={} + description={ + searchKeyword.trim() + ? t('未找到匹配的模型') + : (Object.keys(differences).length === 0 ? + (hasSynced ? t('暂无差异化倍率显示') : t('请先选择同步渠道')) + : t('请先选择同步渠道')) + } + style={{ padding: 30 }} + /> + ); + } + + const columns = [ + { + title: t('模型'), + dataIndex: 'model', + fixed: 'left', + }, + { + title: t('倍率类型'), + dataIndex: 'ratioType', + render: (text) => { + const typeMap = { + model_ratio: t('模型倍率'), + completion_ratio: t('补全倍率'), + cache_ratio: t('缓存倍率'), + model_price: t('固定价格'), + }; + return {typeMap[text] || text}; + }, + }, + { + title: t('当前值'), + dataIndex: 'current', + render: (text) => ( + + {text !== null && text !== undefined ? text : t('未设置')} + + ), + }, + ...upstreamNames.map((upName) => { + const channelStats = (() => { + let selectableCount = 0; + let selectedCount = 0; + + filteredDataSource.forEach((row) => { + const upstreamVal = row.upstreams?.[upName]; + if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { + selectableCount++; + const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal; + if (isSelected) { + selectedCount++; + } + } + }); + + return { + selectableCount, + selectedCount, + allSelected: selectableCount > 0 && selectedCount === selectableCount, + partiallySelected: selectedCount > 0 && selectedCount < selectableCount, + hasSelectableItems: selectableCount > 0 + }; + })(); + + const handleBulkSelect = (checked) => { + setResolutions((prev) => { + const newRes = { ...prev }; + + filteredDataSource.forEach((row) => { + const upstreamVal = row.upstreams?.[upName]; + if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') { + if (checked) { + if (!newRes[row.model]) newRes[row.model] = {}; + newRes[row.model][row.ratioType] = upstreamVal; + } else { + if (newRes[row.model]) { + delete newRes[row.model][row.ratioType]; + if (Object.keys(newRes[row.model]).length === 0) { + delete newRes[row.model]; + } + } + } + } + }); + + return newRes; + }); + }; + + return { + title: channelStats.hasSelectableItems ? ( + handleBulkSelect(e.target.checked)} + > + {upName} + + ) : ( + {upName} + ), + dataIndex: upName, + render: (_, record) => { + const upstreamVal = record.upstreams?.[upName]; + + if (upstreamVal === null || upstreamVal === undefined) { + return {t('未设置')}; + } + + if (upstreamVal === 'same') { + return {t('与本地相同')}; + } + + const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal; + + return ( + { + const isChecked = e.target.checked; + if (isChecked) { + selectValue(record.model, record.ratioType, upstreamVal); + } else { + setResolutions((prev) => { + const newRes = { ...prev }; + if (newRes[record.model]) { + delete newRes[record.model][record.ratioType]; + if (Object.keys(newRes[record.model]).length === 0) { + delete newRes[record.model]; + } + } + return newRes; + }); + } + }} + > + {upstreamVal} + + ); + }, + }; + }), + ]; + + return ( + t('第 {{start}} - {{end}} 条,共 {{total}} 条', { + start: page.currentStart, + end: page.currentEnd, + total: filteredDataSource.length, + }), + pageSizeOptions: ['5', '10', '20', '50'], + onChange: (page, size) => { + setCurrentPage(page); + setPageSize(size); + }, + onShowSizeChange: (current, size) => { + setCurrentPage(1); + setPageSize(size); + } + }} + scroll={{ x: 'max-content' }} + size='middle' + loading={loading || syncLoading} + className="rounded-xl overflow-hidden" + /> + ); + }; + + const updateChannelEndpoint = useCallback((channelId, endpoint) => { + setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint })); + }, []); + + return ( + <> + + {renderDifferenceTable()} + + + setModalVisible(false)} + onOk={confirmChannelSelection} + allChannels={allChannels} + selectedChannelIds={selectedChannelIds} + setSelectedChannelIds={setSelectedChannelIds} + channelEndpoints={channelEndpoints} + updateChannelEndpoint={updateChannelEndpoint} + /> + + ); +} \ No newline at end of file