Merge branch 'alpha'

This commit is contained in:
CaIon
2025-06-21 01:26:57 +08:00
66 changed files with 2135 additions and 167 deletions

View File

@@ -241,6 +241,7 @@ const (
ChannelTypeXinference = 47 ChannelTypeXinference = 47
ChannelTypeXai = 48 ChannelTypeXai = 48
ChannelTypeCoze = 49 ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeDummy // this one is only for count, do not add any channel after this ChannelTypeDummy // this one is only for count, do not add any channel after this
) )
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
"", //47 "", //47
"https://api.x.ai", //48 "https://api.x.ai", //48
"https://api.coze.cn", //49 "https://api.coze.cn", //49
"https://api.klingai.com", //50
} }

View File

@@ -13,6 +13,7 @@ import (
"math/big" "math/big"
"math/rand" "math/rand"
"net" "net"
"net/url"
"os" "os"
"os/exec" "os/exec"
"runtime" "runtime"
@@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
} }
return strconv.ParseFloat(durationStr, 64) return strconv.ParseFloat(durationStr, 64)
} }
// BuildURL concatenates base and endpoint, returns the complete url string
func BuildURL(base string, endpoint string) string {
u, err := url.Parse(base)
if err != nil {
return base + endpoint
}
end := endpoint
if end == "" {
end = "/"
}
ref, err := url.Parse(end)
if err != nil {
return base + endpoint
}
return u.ResolveReference(ref).String()
}

View File

@@ -5,6 +5,7 @@ type TaskPlatform string
const ( const (
TaskPlatformSuno TaskPlatform = "suno" TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj" TaskPlatformMidjourney = "mj"
TaskPlatformKling TaskPlatform = "kling"
) )
const ( const (

View File

@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeSunoAPI { if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil return errors.New("suno channel test is not supported"), nil
} }
if channel.Type == common.ChannelTypeKling {
return errors.New("kling channel test is not supported"), nil
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
info := relaycommon.GenRelayInfo(c) info := relaycommon.GenRelayInfo(c)
err = helper.ModelMappedHelper(c, info) err = helper.ModelMappedHelper(c, info, nil)
if err != nil { if err != nil {
return err, nil return err, nil
} }

View File

@@ -2,6 +2,7 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/samber/lo"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@@ -136,6 +137,9 @@ func init() {
adaptor.Init(meta) adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList() channelId2Models[i] = adaptor.GetModelList()
} }
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
return m.Id
})
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {

View File

@@ -0,0 +1,24 @@
package controller
import (
"net/http"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
func GetRatioConfig(c *gin.Context) {
if !ratio_setting.IsExposeRatioEnabled() {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "倍率配置接口未启用",
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": ratio_setting.GetExposedData(),
})
}

322
controller/ratio_sync.go Normal file
View File

@@ -0,0 +1,322 @@
package controller
import (
"context"
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"one-api/common"
"one-api/dto"
"one-api/model"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
const (
defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config"
maxConcurrentFetches = 8
)
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct {
Name string `json:"name"`
Data map[string]any `json:"data,omitempty"`
Err string `json:"err,omitempty"`
}
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}
if req.Timeout <= 0 {
req.Timeout = defaultTimeoutSeconds
}
var upstreams []dto.UpstreamDTO
if len(req.ChannelIDs) > 0 {
intIds := make([]int, 0, len(req.ChannelIDs))
for _, id64 := range req.ChannelIDs {
intIds = append(intIds, int(id64))
}
dbChannels, err := model.GetChannelsByIds(intIds)
if err != nil {
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
return
}
for _, ch := range dbChannels {
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
upstreams = append(upstreams, dto.UpstreamDTO{
Name: ch.Name,
BaseURL: strings.TrimRight(base, "/"),
Endpoint: "",
})
}
}
}
if len(upstreams) == 0 {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
return
}
var wg sync.WaitGroup
ch := make(chan upstreamResult, len(upstreams))
sem := make(chan struct{}, maxConcurrentFetches)
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
for _, chn := range upstreams {
wg.Add(1)
go func(chItem dto.UpstreamDTO) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
endpoint := chItem.Endpoint
if endpoint == "" {
endpoint = defaultEndpoint
} else if !strings.HasPrefix(endpoint, "/") {
endpoint = "/" + endpoint
}
fullURL := chItem.BaseURL + endpoint
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil {
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
resp, err := client.Do(httpReq)
if err != nil {
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
return
}
var body struct {
Success bool `json:"success"`
Data map[string]any `json:"data"`
Message string `json:"message"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
return
}
if !body.Success {
ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
return
}
ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
}(chn)
}
wg.Wait()
close(ch)
localData := ratio_setting.GetExposedData()
var testResults []dto.TestResult
var successfulChannels []struct {
name string
data map[string]any
}
for r := range ch {
if r.Err != "" {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "error",
Error: r.Err,
})
} else {
testResults = append(testResults, dto.TestResult{
Name: r.Name,
Status: "success",
})
successfulChannels = append(successfulChannels, struct {
name string
data map[string]any
}{name: r.Name, data: r.Data})
}
}
differences := buildDifferences(localData, successfulChannels)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"differences": differences,
"test_results": testResults,
},
})
}
func buildDifferences(localData map[string]any, successfulChannels []struct {
name string
data map[string]any
}) map[string]map[string]dto.DifferenceItem {
differences := make(map[string]map[string]dto.DifferenceItem)
allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio {
allModels[modelName] = struct{}{}
}
}
}
}
for modelName := range allModels {
for _, ratioType := range ratioTypes {
var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists {
localValue = val
}
}
}
upstreamValues := make(map[string]interface{})
hasUpstreamValue := false
hasDifference := false
for _, channel := range successfulChannels {
var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val
hasUpstreamValue = true
if localValue != nil && localValue != val {
hasDifference = true
} else if localValue == val {
upstreamValue = "same"
}
}
}
if upstreamValue == nil && localValue == nil {
upstreamValue = "same"
}
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
hasDifference = true
}
upstreamValues[channel.name] = upstreamValue
}
shouldInclude := false
if localValue != nil {
if hasDifference {
shouldInclude = true
}
} else {
if hasUpstreamValue {
shouldInclude = true
}
}
if shouldInclude {
if differences[modelName] == nil {
differences[modelName] = make(map[string]dto.DifferenceItem)
}
differences[modelName][ratioType] = dto.DifferenceItem{
Current: localValue,
Upstreams: upstreamValues,
}
}
}
}
channelHasDiff := make(map[string]bool)
for _, ratioMap := range differences {
for _, item := range ratioMap {
for chName, val := range item.Upstreams {
if val != nil && val != "same" {
channelHasDiff[chName] = true
}
}
}
}
for modelName, ratioMap := range differences {
for ratioType, item := range ratioMap {
for chName := range item.Upstreams {
if !channelHasDiff[chName] {
delete(item.Upstreams, chName)
}
}
differences[modelName][ratioType] = item
}
}
return differences
}
func GetSyncableChannels(c *gin.Context) {
channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var syncableChannels []dto.SyncableChannel
for _, channel := range channels {
if channel.GetBaseURL() != "" {
syncableChannels = append(syncableChannels, dto.SyncableChannel{
ID: channel.Id,
Name: channel.Name,
BaseURL: channel.GetBaseURL(),
Status: channel.Status,
})
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": syncableChannels,
})
}

View File

@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError var err *dto.TaskError
switch relayMode { switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
err = relay.RelayTaskFetch(c, relayMode) err = relay.RelayTaskFetch(c, relayMode)
default: default:
err = relay.RelayTaskSubmit(c, relayMode) err = relay.RelayTaskSubmit(c, relayMode)

View File

@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks) //_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
case constant.TaskPlatformSuno: case constant.TaskPlatformSuno:
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
case constant.TaskPlatformKling:
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
default: default:
common.SysLog("未知平台") common.SysLog("未知平台")
} }

140
controller/task_video.go Normal file
View File

@@ -0,0 +1,140 @@
package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/relay"
"one-api/relay/channel"
)
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
}
}
return nil
}
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
return nil
}
cacheGetChannel, err := model.CacheGetChannel(channelId)
if err != nil {
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if errUpdate != nil {
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
}
return fmt.Errorf("CacheGetChannel failed: %w", err)
}
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
if adaptor == nil {
return fmt.Errorf("video adaptor not found")
}
for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
}
}
return nil
}
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
"task_id": taskId,
})
if err != nil {
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
}
var responseItem map[string]interface{}
err = json.Unmarshal(responseBody, &responseItem)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
}
code, _ := responseItem["code"].(float64)
if code != 0 {
return fmt.Errorf("video task fetch failed for task %s", taskId)
}
data, ok := responseItem["data"].(map[string]interface{})
if !ok {
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
return fmt.Errorf("video task data format error for task %s", taskId)
}
task := taskM[taskId]
if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
if status, ok := data["task_status"].(string); ok {
switch status {
case "submitted", "queued":
task.Status = model.TaskStatusSubmitted
case "processing":
task.Status = model.TaskStatusInProgress
case "succeed":
task.Status = model.TaskStatusSuccess
task.Progress = "100%"
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
task.FailReason = url
} else {
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
}
case "failed":
task.Status = model.TaskStatusFailure
task.Progress = "100%"
if reason, ok := data["fail_reason"].(string); ok {
task.FailReason = reason
}
}
}
// If task failed, refund quota
if task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota
if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
task.Data = responseBody
if err := task.Update(); err != nil {
common.SysError("UpdateVideoTask task error: " + err.Error())
}
return nil
}

View File

@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
return return
} }
payType := "wxpay"
if req.PaymentMethod == "zfb" { if !setting.ContainsPayMethod(req.PaymentMethod) {
payType = "alipay" c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
} return
if req.PaymentMethod == "wx" {
req.PaymentMethod = "wxpay"
payType = "wxpay"
} }
callBackAddress := service.GetCallbackAddress() callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
return return
} }
uri, params, err := client.Purchase(&epay.PurchaseArgs{ uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: payType, Type: req.PaymentMethod,
ServiceTradeNo: tradeNo, ServiceTradeNo: tradeNo,
Name: fmt.Sprintf("TUC%d", req.Amount), Name: fmt.Sprintf("TUC%d", req.Amount),
Money: strconv.FormatFloat(payMoney, 'f', 2, 64), Money: strconv.FormatFloat(payMoney, 'f', 2, 64),

49
dto/ratio_sync.go Normal file
View File

@@ -0,0 +1,49 @@
package dto
// UpstreamDTO 提交到后端同步倍率的上游渠道信息
// Endpoint 可以为空,后端会默认使用 /api/ratio_config
// BaseURL 必须以 http/https 开头,不要以 / 结尾
// 例如: https://api.example.com
// Endpoint: /api/ratio_config
// 提交示例:
// {
// "name": "openai",
// "base_url": "https://api.openai.com",
// "endpoint": "/ratio_config"
// }
type UpstreamDTO struct {
Name string `json:"name" binding:"required"`
BaseURL string `json:"base_url" binding:"required"`
Endpoint string `json:"endpoint"`
}
type UpstreamRequest struct {
ChannelIDs []int64 `json:"channel_ids"`
Timeout int `json:"timeout"`
}
// TestResult 上游测试连通性结果
type TestResult struct {
Name string `json:"name"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
}
// DifferenceItem 差异项
// Current 为本地值,可能为 nil
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
type DifferenceItem struct {
Current interface{} `json:"current"`
Upstreams map[string]interface{} `json:"upstreams"`
}
// SyncableChannel 可同步的渠道信息base_url 不为空)
type SyncableChannel struct {
ID int `json:"id"`
Name string `json:"name"`
BaseURL string `json:"base_url"`
Status int `json:"status"`
}

47
dto/video.go Normal file
View File

@@ -0,0 +1,47 @@
package dto
type VideoRequest struct {
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
Width int `json:"width" example:"512"` // Video width
Height int `json:"height" example:"512"` // Video height
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
User string `json:"user,omitempty" example:"user-1234"` // User identifier
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
}
// VideoResponse 视频生成提交任务后的响应
type VideoResponse struct {
TaskId string `json:"task_id"`
Status string `json:"status"`
}
// VideoTaskResponse 查询视频生成任务状态的响应
type VideoTaskResponse struct {
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
Status string `json:"status" example:"succeeded"` // 任务状态
Url string `json:"url,omitempty"` // 视频资源URL成功时
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
}
// VideoTaskMetadata 视频任务元数据
type VideoTaskMetadata struct {
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
Fps int `json:"fps" example:"30"` // 实际帧率
Width int `json:"width" example:"512"` // 实际宽度
Height int `json:"height" example:"512"` // 实际高度
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
}
// VideoTaskError 视频任务错误信息
type VideoTaskError struct {
Code int `json:"code"`
Message string `json:"message"`
}

View File

@@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
} }
c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("platform", string(constant.TaskPlatformSuno))
c.Set("relay_mode", relayMode) c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeKlingFetchByID {
shouldSelectChannel = false
} else {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
c.Set("platform", string(constant.TaskPlatformKling))
c.Set("relay_mode", relayMode)
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent // Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
relayMode := relayconstant.RelayModeGemini relayMode := relayconstant.RelayModeGemini

View File

@@ -127,6 +127,7 @@ func InitOptionMap() {
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString() common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength) common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString() common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
// 自动添加所有注册的模型配置 // 自动添加所有注册的模型配置
modelConfigs := config.GlobalConfig.ExportAllConfigs() modelConfigs := config.GlobalConfig.ExportAllConfigs()
@@ -267,6 +268,8 @@ func updateOptionMap(key string, value string) (err error) {
setting.WorkerAllowHttpImageRequestEnabled = boolValue setting.WorkerAllowHttpImageRequestEnabled = boolValue
case "DefaultUseAutoGroup": case "DefaultUseAutoGroup":
setting.DefaultUseAutoGroup = boolValue setting.DefaultUseAutoGroup = boolValue
case "ExposeRatioEnabled":
ratio_setting.SetExposeRatioEnabled(boolValue)
} }
} }
switch key { switch key {

View File

@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
} }
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo) audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil { if err != nil {
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := 0 promptTokens := 0
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model) promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
}
preConsumedTokens = promptTokens preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens relayInfo.PromptTokens = promptTokens
} }
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
}() }()
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
audioRequest.Model = relayInfo.UpstreamModelName
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil { if adaptor == nil {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)

View File

@@ -44,4 +44,6 @@ type TaskAdaptor interface {
// FetchTask // FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseResultUrl(resp map[string]any) (string, error)
} }

View File

@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) { func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens) claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else { } else {
if claudeInfo.Usage.PromptTokens == 0 { if claudeInfo.Usage.PromptTokens == 0 {
//上游出错 //上游出错
@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
if common.DebugEnabled { if common.DebugEnabled {
common.SysError("claude response usage is not complete, maybe upstream error") common.SysError("claude response usage is not complete, maybe upstream error")
} }
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
} }
} }
@@ -618,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
} }
} }
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
}
claudeInfo.Usage.PromptTokens = info.PromptTokens claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens

View File

@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error()) common.LogError(c, "error_scanning_stream_response: "+err.Error())
} }
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response) err := helper.ObjectData(c, response)
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
for _, choice := range response.Choices { for _, choice := range response.Choices {
responseText += choice.Message.StringContent() responseText += choice.Message.StringContent()
} }
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage response.Usage = *usage
response.Id = helper.GetResponseID(c) response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage return nil, usage

View File

@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} }
}) })
if usage.PromptTokens == 0 { if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} }
return nil, usage return nil, usage
} }

View File

@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
var currentEvent string var currentEvent string
var currentData string var currentData string
var usage dto.Usage var usage = &dto.Usage{}
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
if line == "" { if line == "" {
if currentEvent != "" && currentData != "" { if currentEvent != "" && currentData != "" {
// handle last event // handle last event
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
currentEvent = "" currentEvent = ""
currentData = "" currentData = ""
} }
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
// Last event // Last event
if currentEvent != "" && currentData != "" { if currentEvent != "" && currentData != "" {
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info) handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
helper.Done(c) helper.Done(c)
if usage.TotalTokens == 0 { if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
return nil, &usage return nil, usage
} }
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) { func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

View File

@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
return true return true
}) })
helper.Done(c) helper.Done(c)
err := resp.Body.Close()
if err != nil {
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
common.SysError("close_response_body_failed: " + err.Error())
}
if usage.TotalTokens == 0 { if usage.TotalTokens == 0 {
usage.PromptTokens = info.PromptTokens usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
usage.CompletionTokens += nodeToken usage.CompletionTokens += nodeToken
return nil, usage return nil, usage

View File

@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式 // 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.OriginModelName, "-thinking-") { if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-") parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0] info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配 } else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") { } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} }
} }

View File

@@ -9,6 +9,7 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -75,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
helper.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err := common.DecodeJsonStr(data, &geminiResponse) err := common.DecodeJsonStr(data, &geminiResponse)
@@ -89,6 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
if part.InlineData != nil && part.InlineData.MimeType != "" { if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++ imageCount++
} }
if part.Text != "" {
responseText.WriteString(part.Text)
}
} }
} }
@@ -108,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
} }
// 直接发送 GeminiChatResponse 响应 // 直接发送 GeminiChatResponse 响应
err = helper.ObjectData(c, geminiResponse) err = helper.StringData(c, data)
if err != nil { if err != nil {
common.LogError(c, err.Error()) common.LogError(c, err.Error())
} }
@@ -122,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
} }
} }
// 计算最终使用量 // 如果usage.CompletionTokens为0则使用本地统计的completion tokens
// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
// 移除流式响应结尾的[Done]因为Gemini API没有发送Done的行为 // 移除流式响应结尾的[Done]因为Gemini API没有发送Done的行为
//helper.Done(c) //helper.Done(c)

View File

@@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
} }
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.OriginModelName modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") && isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") && !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25") !strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")

View File

@@ -8,7 +8,6 @@ import (
"math" "math"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"path/filepath"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
@@ -16,6 +15,7 @@ import (
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"os" "os"
"path/filepath"
"strings" "strings"
"github.com/bytedance/gopkg/util/gopool" "github.com/bytedance/gopkg/util/gopool"
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
if !containStreamUsage { if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7 usage.CompletionTokens += toolCount * 7
} else { } else {
if info.ChannelType == common.ChannelTypeDeepSeek { if info.ChannelType == common.ChannelTypeDeepSeek {
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil }, nil
} }
forceFormat := false forceFormat := false
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok { if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
forceFormat = forceFmt forceFormat = forceFmt
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0 completionTokens := 0
for _, choice := range simpleResponse.Choices { for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName) ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm completionTokens += ctkm
} }
simpleResponse.Usage = dto.Usage{ simpleResponse.Usage = dto.Usage{
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
// the status code has been judged before, if there is a body reading failure, // the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry. // it should be regarded as a non-recoverable error, so it should not return err for external retry.
// Analogous to nginx's load balancing, it will only retry if it can't be requested or // Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header, // if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error, // the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly. // and can be terminated directly.
defer resp.Body.Close() defer resp.Body.Close()
usage := &dto.Usage{} usage := &dto.Usage{}
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
if err = c.ShouldBind(&reqBody); err != nil { if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err) return 0, errors.WithStack(err)
} }
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名 ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open() reqFp, err := reqBody.File.Open()
if err != nil { if err != nil {
return 0, errors.WithStack(err) return 0, errors.WithStack(err)
} }
defer reqFp.Close() defer reqFp.Close()
tmpFp, err := os.CreateTemp("", "audio-*"+ext) tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil { if err != nil {

View File

@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
tempStr := responseTextBuilder.String() tempStr := responseTextBuilder.String()
if len(tempStr) > 0 { if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量 // 非正常结束,使用输出文本的 token 数量
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName) completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens usage.CompletionTokens = completionTokens
} }
} }

View File

@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = palmStreamHandler(c, resp) err, responseText = palmStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }

View File

@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil }, nil
} }
fullTextResponse := responsePaLM2OpenAI(&palmResponse) fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model) completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
usage := dto.Usage{ usage := dto.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,

View File

@@ -0,0 +1,312 @@
package kling
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Image string `json:"image,omitempty"`
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Model string `json:"model,omitempty"`
ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
accessKey string
secretKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
// apiKey format: "access_key,secret_key"
keyParts := strings.Split(info.ApiKey, ",")
if len(keyParts) == 2 {
a.accessKey = strings.TrimSpace(keyParts[0])
a.secretKey = strings.TrimSpace(keyParts[1])
}
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := "generate"
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("kling_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
token, err := a.createJWTToken()
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return nil
}
// BuildRequestBody converts request into Kling specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("kling_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
body := a.convertToRequestPayload(&req)
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
// Attempt Kling response parse first.
var kResp responsePayload
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
return kResp.Data.TaskID, responseBody, nil
}
// Fallback generic task response.
var generic dto.TaskResponse[string]
if err := json.Unmarshal(responseBody, &generic); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if !generic.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
return generic.Data, responseBody, nil
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
token, err := a.createJWTTokenWithKey(key)
if err != nil {
token = key
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
req = req.WithContext(ctx)
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "kling"
}
// ============================
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
r := &requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
Model: req.Model,
ModelName: req.Model,
CfgScale: 0.5,
}
if r.Model == "" {
r.Model = "kling-v1"
r.ModelName = "kling-v1"
}
return r
}
func (a *TaskAdaptor) getAspectRatio(size string) string {
switch size {
case "1024x1024", "512x512":
return "1:1"
case "1280x720", "1920x1080":
return "16:9"
case "720x1280", "1080x1920":
return "9:16"
default:
return "1:1"
}
}
func defaultString(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
func defaultInt(v int, def int) int {
if v == 0 {
return def
}
return v
}
// ============================
// JWT helpers
// ============================
func (a *TaskAdaptor) createJWTToken() (string, error) {
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
parts := strings.Split(apiKey, ",")
if len(parts) != 2 {
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
}
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
}
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
if accessKey == "" || secretKey == "" {
return "", fmt.Errorf("access key and secret key are required")
}
now := time.Now().Unix()
claims := jwt.MapClaims{
"iss": accessKey,
"exp": now + 1800, // 30 minutes
"nbf": now - 5,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["typ"] = "JWT"
return token.SignedString([]byte(secretKey))
}
// ParseResultUrl 提取视频任务结果的 url
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
data, ok := resp["data"].(map[string]any)
if !ok {
return "", fmt.Errorf("data field not found or invalid")
}
taskResult, ok := data["task_result"].(map[string]any)
if !ok {
return "", fmt.Errorf("task_result field not found or invalid")
}
videos, ok := taskResult["videos"].([]interface{})
if !ok || len(videos) == 0 {
return "", fmt.Errorf("videos field not found or empty")
}
video, ok := videos[0].(map[string]interface{})
if !ok {
return "", fmt.Errorf("video item invalid")
}
url, ok := video["url"].(string)
if !ok || url == "" {
return "", fmt.Errorf("url field not found or invalid")
}
return url, nil
}

View File

@@ -22,6 +22,10 @@ type TaskAdaptor struct {
ChannelType int ChannelType int
} }
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
return "", nil // todo implement this method if needed
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType a.ChannelType = info.ChannelType
} }

View File

@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = tencentStreamHandler(c, resp) err, responseText = tencentStreamHandler(c, resp)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = tencentHandler(c, resp) err, usage = tencentHandler(c, resp)
} }

View File

@@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
suffix := "" suffix := ""
if a.RequestMode == RequestModeGemini { if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// suffix -thinking and -nothinking // 新增逻辑:处理 -thinking-<budget> 格式
if strings.HasSuffix(info.OriginModelName, "-thinking") { if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") { } else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking") info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} }
} }

View File

@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}) })
if !containStreamUsage { if !containStreamUsage {
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7 usage.CompletionTokens += toolCount * 7
} }

View File

@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo.IsStream = true relayInfo.IsStream = true
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil { if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
textRequest.Model = relayInfo.UpstreamModelName
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo) promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误 // count messages token error 计算promptTokens错误
if err != nil { if err != nil {
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
var httpResp *http.Response var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody) resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil { if err != nil {
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
if resp != nil { if resp != nil {

View File

@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
} }
const ( const (
RelayFormatOpenAI = "openai" RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude" RelayFormatClaude = "claude"
RelayFormatGemini = "gemini" RelayFormatGemini = "gemini"
RelayFormatOpenAIResponses = "openai_responses"
RelayFormatOpenAIAudio = "openai_audio"
RelayFormatOpenAIImage = "openai_image"
RelayFormatRerank = "rerank"
RelayFormatEmbedding = "embedding"
) )
type RerankerInfo struct { type RerankerInfo struct {
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo { func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank info.RelayMode = relayconstant.RelayModeRerank
info.RelayFormat = RelayFormatRerank
info.RerankerInfo = &RerankerInfo{ info.RerankerInfo = &RerankerInfo{
Documents: req.Documents, Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(), ReturnDocuments: req.GetReturnDocuments(),
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
return info return info
} }
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIAudio
return info
}
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatEmbedding
return info
}
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo { func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
info := GenRelayInfo(c) info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeResponses info.RelayMode = relayconstant.RelayModeResponses
info.RelayFormat = RelayFormatOpenAIResponses
info.SupportStreamOptions = false
info.ResponsesUsageInfo = &ResponsesUsageInfo{ info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo), BuiltInTools: make(map[string]*BuildInToolInfo),
} }
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
return info return info
} }
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatGemini
info.ShouldIncludeUsage = false
return info
}
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatOpenAIImage
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
if streamSupportedChannels[info.ChannelType] { if streamSupportedChannels[info.ChannelType] {
info.SupportStreamOptions = true info.SupportStreamOptions = true
} }
// responses 模式不支持 StreamOptions
if relayconstant.RelayModeResponses == info.RelayMode {
info.SupportStreamOptions = false
}
return info return info
} }

View File

@@ -38,6 +38,9 @@ const (
RelayModeSunoFetchByID RelayModeSunoFetchByID
RelayModeSunoSubmit RelayModeSunoSubmit
RelayModeKlingFetchByID
RelayModeKlingSubmit
RelayModeRerank RelayModeRerank
RelayModeResponses RelayModeResponses
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
} }
return relayMode return relayMode
} }
func Path2RelayKling(method, path string) int {
relayMode := RelayModeUnknown
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
relayMode = RelayModeKlingSubmit
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
relayMode = RelayModeKlingFetchByID
}
return relayMode
}

View File

@@ -15,7 +15,7 @@ import (
) )
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int { func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model) token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
return token return token
} }
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
} }
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoEmbedding(c)
var embeddingRequest *dto.EmbeddingRequest var embeddingRequest *dto.EmbeddingRequest
err := common.UnmarshalBodyReusable(c, &embeddingRequest) err := common.UnmarshalBodyReusable(c, &embeddingRequest)
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
embeddingRequest.Model = relayInfo.UpstreamModelName
promptToken := getEmbeddingPromptToken(*embeddingRequest) promptToken := getEmbeddingPromptToken(*embeddingRequest)
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken

View File

@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
return sensitiveWords, err return sensitiveWords, err
} }
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) { func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量 // 计算输入 token 数量
var inputTexts []string var inputTexts []string
for _, content := range req.Contents { for _, content := range req.Contents {
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
} }
inputText := strings.Join(inputTexts, "\n") inputText := strings.Join(inputTexts, "\n")
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName) inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
info.PromptTokens = inputTokens info.PromptTokens = inputTokens
return inputTokens, err return inputTokens
} }
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
} }
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoGemini(c)
// 检查 Gemini 流式模式 // 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo) checkGeminiStreamMode(c, relayInfo)
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
// model mapped 模型映射 // model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
} }
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
promptTokens := value.(int) promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens) relayInfo.SetPromptTokens(promptTokens)
} else { } else {
promptTokens, err := getGeminiInputTokens(req, relayInfo) promptTokens := getGeminiInputTokens(req, relayInfo)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
} }
@@ -162,7 +162,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody)) resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
if err != nil { if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error()) common.LogError(c, "Do gemini request failed: "+err.Error())
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError) return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
} }
statusCodeMappingStr := c.GetString("status_code_mapping") statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -4,12 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
common2 "one-api/common"
"one-api/dto"
"one-api/relay/common" "one-api/relay/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error { func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
if modelMapping != "" && modelMapping != "{}" { if modelMapping != "" && modelMapping != "{}" {
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
info.UpstreamModelName = currentModel info.UpstreamModelName = currentModel
} }
} }
if request != nil {
switch info.RelayFormat {
case common.RelayFormatGemini:
// Gemini 模型映射
case common.RelayFormatClaude:
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
claudeRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIResponses:
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
openAIResponsesRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIAudio:
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
openAIAudioRequest.Model = info.UpstreamModelName
}
case common.RelayFormatOpenAIImage:
if imageRequest, ok := request.(*dto.ImageRequest); ok {
imageRequest.Model = info.UpstreamModelName
}
case common.RelayFormatRerank:
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
rerankRequest.Model = info.UpstreamModelName
}
case common.RelayFormatEmbedding:
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
embeddingRequest.Model = info.UpstreamModelName
}
default:
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
openAIRequest.Model = info.UpstreamModelName
} else {
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
}
}
}
return nil return nil
} }

View File

@@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
} }
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
relayInfo := relaycommon.GenRelayInfo(c) relayInfo := relaycommon.GenRelayInfoImage(c)
imageRequest, err := getAndValidImageRequest(c, relayInfo) imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil { if err != nil {
@@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
imageRequest.Model = relayInfo.UpstreamModelName
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0) priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)

View File

@@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
} }
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
textRequest.Model = relayInfo.UpstreamModelName
// 获取 promptTokens如果上下文中已经存在则直接使用 // 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists { if value, exists := c.Get("prompt_tokens"); exists {
@@ -253,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
case relayconstant.RelayModeChatCompletions: case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(info, *textRequest) promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions: case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model) promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations: case relayconstant.RelayModeModerations:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model) promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
default: default:
err = errors.New("unknown relay mode") err = errors.New("unknown relay mode")
promptTokens = 0 promptTokens = 0

View File

@@ -22,6 +22,7 @@ import (
"one-api/relay/channel/palm" "one-api/relay/channel/palm"
"one-api/relay/channel/perplexity" "one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow" "one-api/relay/channel/siliconflow"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno" "one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent" "one-api/relay/channel/tencent"
"one-api/relay/channel/vertex" "one-api/relay/channel/vertex"
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
// return &aiproxy.Adaptor{} // return &aiproxy.Adaptor{}
case commonconstant.TaskPlatformSuno: case commonconstant.TaskPlatformSuno:
return &suno.TaskAdaptor{} return &suno.TaskAdaptor{}
case commonconstant.TaskPlatformKling:
return &kling.TaskAdaptor{}
} }
return nil return nil
} }

View File

@@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
} }
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
if platform == constant.TaskPlatformKling {
modelName = relayInfo.OriginModelName
}
modelPrice, success := ratio_setting.GetModelPrice(modelName, true) modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
if !success { if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
@@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
} }
relayInfo.ConsumeQuota = true relayInfo.ConsumeQuota = true
// insert task // insert task
task := model.InitTask(constant.TaskPlatformSuno, relayInfo) task := model.InitTask(platform, relayInfo)
task.TaskID = taskID task.TaskID = taskID
task.Quota = quota task.Quota = quota
task.Data = taskData task.Data = taskData
task.Action = relayInfo.Action
err = task.Insert() err = task.Insert()
if err != nil { if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
@@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
} }
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
} }
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
@@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
return return
} }
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
return
}
func TaskModel2Dto(task *model.Task) *dto.TaskDto { func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{ return &dto.TaskDto{
TaskID: task.TaskID, TaskID: task.TaskID,

View File

@@ -14,12 +14,10 @@ import (
) )
func getRerankPromptToken(rerankRequest dto.RerankRequest) int { func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model) token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
for _, document := range rerankRequest.Documents { for _, document := range rerankRequest.Documents {
tkm, err := service.CountTokenInput(document, rerankRequest.Model) tkm := service.CountTokenInput(document, rerankRequest.Model)
if err == nil { token += tkm
token += tkm
}
} }
return token return token
} }
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
} }
rerankRequest.Model = relayInfo.UpstreamModelName
promptToken := getRerankPromptToken(*rerankRequest) promptToken := getRerankPromptToken(*rerankRequest)
relayInfo.PromptTokens = promptToken relayInfo.PromptTokens = promptToken

View File

@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
return sensitiveWords, err return sensitiveWords, err
} }
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) { func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
inputTokens, err := service.CountTokenInput(req.Input, req.Model) inputTokens := service.CountTokenInput(req.Input, req.Model)
info.PromptTokens = inputTokens info.PromptTokens = inputTokens
return inputTokens, err return inputTokens
} }
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
} }
} }
err = helper.ModelMappedHelper(c, relayInfo) err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
} }
req.Model = relayInfo.UpstreamModelName
if value, exists := c.Get("prompt_tokens"); exists { if value, exists := c.Get("prompt_tokens"); exists {
promptTokens := value.(int) promptTokens := value.(int)
relayInfo.SetPromptTokens(promptTokens) relayInfo.SetPromptTokens(promptTokens)
} else { } else {
promptTokens, err := getInputTokens(req, relayInfo) promptTokens := getInputTokens(req, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
}
c.Set("prompt_tokens", promptTokens) c.Set("prompt_tokens", promptTokens)
} }

View File

@@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin) apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind) apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
userRoute := apiRouter.Group("/user") userRoute := apiRouter.Group("/user")
{ {
@@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) {
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio) optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除 optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
} }
ratioSyncRoute := apiRouter.Group("/ratio_sync")
ratioSyncRoute.Use(middleware.RootAuth())
{
ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
}
channelRoute := apiRouter.Group("/channel") channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth()) channelRoute.Use(middleware.AdminAuth())
{ {

View File

@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
SetApiRouter(router) SetApiRouter(router)
SetDashboardRouter(router) SetDashboardRouter(router)
SetRelayRouter(router) SetRelayRouter(router)
SetVideoRouter(router)
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
if common.IsMasterNode && frontendBaseUrl != "" { if common.IsMasterNode && frontendBaseUrl != "" {
frontendBaseUrl = "" frontendBaseUrl = ""

17
router/video-router.go Normal file
View File

@@ -0,0 +1,17 @@
package router
import (
"one-api/controller"
"one-api/middleware"
"github.com/gin-gonic/gin"
)
func SetVideoRouter(router *gin.Engine) {
videoV1Router := router.Group("/v1")
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
videoV1Router.POST("/video/generations", controller.RelayTask)
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
}
}

View File

@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
countStr += fmt.Sprintf("%v", tool.Function.Parameters) countStr += fmt.Sprintf("%v", tool.Function.Parameters)
} }
} }
toolTokens, err := CountTokenInput(countStr, request.Model) toolTokens := CountTokenInput(countStr, request.Model)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
// Count tokens in system message // Count tokens in system message
if request.System != "" { if request.System != "" {
systemTokens, err := CountTokenInput(request.System, model) systemTokens := CountTokenInput(request.System, model)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
switch request.Type { switch request.Type {
case dto.RealtimeEventTypeSessionUpdate: case dto.RealtimeEventTypeSessionUpdate:
if request.Session != nil { if request.Session != nil {
msgTokens, err := CountTextToken(request.Session.Instructions, model) msgTokens := CountTextToken(request.Session.Instructions, model)
if err != nil {
return 0, 0, err
}
textToken += msgTokens textToken += msgTokens
} }
case dto.RealtimeEventResponseAudioDelta: case dto.RealtimeEventResponseAudioDelta:
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
audioToken += atk audioToken += atk
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta: case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
// count text token // count text token
tkm, err := CountTextToken(request.Delta, model) tkm := CountTextToken(request.Delta, model)
if err != nil {
return 0, 0, fmt.Errorf("error counting text token: %v", err)
}
textToken += tkm textToken += tkm
case dto.RealtimeEventInputAudioBufferAppend: case dto.RealtimeEventInputAudioBufferAppend:
// count audio token // count audio token
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
case "message": case "message":
for _, content := range request.Item.Content { for _, content := range request.Item.Content {
if content.Type == "input_text" { if content.Type == "input_text" {
tokens, err := CountTextToken(content.Text, model) tokens := CountTextToken(content.Text, model)
if err != nil {
return 0, 0, err
}
textToken += tokens textToken += tokens
} }
} }
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
if !info.IsFirstRequest { if !info.IsFirstRequest {
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 { if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
for _, tool := range info.RealtimeTools { for _, tool := range info.RealtimeTools {
toolTokens, err := CountTokenInput(tool, model) toolTokens := CountTokenInput(tool, model)
if err != nil {
return 0, 0, err
}
textToken += 8 textToken += 8
textToken += toolTokens textToken += toolTokens
} }
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
return tokenNum, nil return tokenNum, nil
} }
func CountTokenInput(input any, model string) (int, error) { func CountTokenInput(input any, model string) int {
switch v := input.(type) { switch v := input.(type) {
case string: case string:
return CountTextToken(v, model) return CountTextToken(v, model)
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int { func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
tokens := 0 tokens := 0
for _, message := range messages { for _, message := range messages {
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model) tkm := CountTokenInput(message.Delta.GetContentString(), model)
tokens += tkm tokens += tkm
if message.Delta.ToolCalls != nil { if message.Delta.ToolCalls != nil {
for _, tool := range message.Delta.ToolCalls { for _, tool := range message.Delta.ToolCalls {
tkm, _ := CountTokenInput(tool.Function.Name, model) tkm := CountTokenInput(tool.Function.Name, model)
tokens += tkm tokens += tkm
tkm, _ = CountTokenInput(tool.Function.Arguments, model) tkm = CountTokenInput(tool.Function.Arguments, model)
tokens += tkm tokens += tkm
} }
} }
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
return tokens return tokens
} }
func CountTTSToken(text string, model string) (int, error) { func CountTTSToken(text string, model string) int {
if strings.HasPrefix(model, "tts") { if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text), nil return utf8.RuneCountInString(text)
} else { } else {
return CountTextToken(text, model) return CountTextToken(text, model)
} }
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
//} //}
// CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量 // CountTextToken 统计文本的token数量仅当文本包含敏感词返回错误同时返回token数量
func CountTextToken(text string, model string) (int, error) { func CountTextToken(text string, model string) int {
var err error if text == "" {
return 0
}
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text), err return getTokenNum(tokenEncoder, text)
} }

View File

@@ -16,13 +16,13 @@ import (
// return 0, errors.New("unknown relay mode") // return 0, errors.New("unknown relay mode")
//} //}
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = promptTokens usage.PromptTokens = promptTokens
ctkm, err := CountTextToken(responseText, modeName) ctkm := CountTextToken(responseText, modeName)
usage.CompletionTokens = ctkm usage.CompletionTokens = ctkm
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage, err return usage
} }
func ValidUsage(usage *dto.Usage) bool { func ValidUsage(usage *dto.Usage) bool {

View File

@@ -13,12 +13,12 @@ var PayMethods = []map[string]string{
{ {
"name": "支付宝", "name": "支付宝",
"color": "rgba(var(--semi-blue-5), 1)", "color": "rgba(var(--semi-blue-5), 1)",
"type": "zfb", "type": "alipay",
}, },
{ {
"name": "微信", "name": "微信",
"color": "rgba(var(--semi-green-5), 1)", "color": "rgba(var(--semi-green-5), 1)",
"type": "wx", "type": "wxpay",
}, },
} }

View File

@@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
cacheRatioMapMutex.Lock() cacheRatioMapMutex.Lock()
defer cacheRatioMapMutex.Unlock() defer cacheRatioMapMutex.Unlock()
cacheRatioMap = make(map[string]float64) cacheRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap) err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
} }
// GetCacheRatio returns the cache ratio for a model // GetCacheRatio returns the cache ratio for a model
@@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) {
} }
return ratio, true return ratio, true
} }
func GetCacheRatioCopy() map[string]float64 {
cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(cacheRatioMap))
for k, v := range cacheRatioMap {
copyMap[k] = v
}
return copyMap
}

View File

@@ -0,0 +1,17 @@
package ratio_setting
import "sync/atomic"
var exposeRatioEnabled atomic.Bool
func init() {
exposeRatioEnabled.Store(false)
}
func SetExposeRatioEnabled(enabled bool) {
exposeRatioEnabled.Store(enabled)
}
func IsExposeRatioEnabled() bool {
return exposeRatioEnabled.Load()
}

View File

@@ -0,0 +1,55 @@
package ratio_setting
import (
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
)
const exposedDataTTL = 30 * time.Second
type exposedCache struct {
data gin.H
expiresAt time.Time
}
var (
exposedData atomic.Value
rebuildMu sync.Mutex
)
func InvalidateExposedDataCache() {
exposedData.Store((*exposedCache)(nil))
}
func cloneGinH(src gin.H) gin.H {
dst := make(gin.H, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
func GetExposedData() gin.H {
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
return cloneGinH(c.data)
}
rebuildMu.Lock()
defer rebuildMu.Unlock()
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
return cloneGinH(c.data)
}
newData := gin.H{
"model_ratio": GetModelRatioCopy(),
"completion_ratio": GetCompletionRatioCopy(),
"cache_ratio": GetCacheRatioCopy(),
"model_price": GetModelPriceCopy(),
}
exposedData.Store(&exposedCache{
data: newData,
expiresAt: time.Now().Add(exposedDataTTL),
})
return cloneGinH(newData)
}

View File

@@ -317,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
modelPriceMapMutex.Lock() modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock() defer modelPriceMapMutex.Unlock()
modelPriceMap = make(map[string]float64) modelPriceMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelPriceMap) err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
} }
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false // GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
@@ -345,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock() modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock() defer modelRatioMapMutex.Unlock()
modelRatioMap = make(map[string]float64) modelRatioMap = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &modelRatioMap) err := json.Unmarshal([]byte(jsonStr), &modelRatioMap)
if err == nil {
InvalidateExposedDataCache()
}
return err
} }
// 处理带有思考预算的模型名称,方便统一定价 // 处理带有思考预算的模型名称,方便统一定价
@@ -405,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatioMutex.Lock() CompletionRatioMutex.Lock()
defer CompletionRatioMutex.Unlock() defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64) CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio) err := json.Unmarshal([]byte(jsonStr), &CompletionRatio)
if err == nil {
InvalidateExposedDataCache()
}
return err
} }
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
@@ -609,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) {
} }
return ratio, true return ratio, true
} }
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelRatioMap))
for k, v := range modelRatioMap {
copyMap[k] = v
}
return copyMap
}
func GetModelPriceCopy() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
copyMap := make(map[string]float64, len(modelPriceMap))
for k, v := range modelPriceMap {
copyMap[k] = v
}
return copyMap
}
func GetCompletionRatioCopy() map[string]float64 {
CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
copyMap := make(map[string]float64, len(CompletionRatio))
for k, v := range CompletionRatio {
copyMap[k] = v
}
return copyMap
}

View File

@@ -0,0 +1,143 @@
import React, { useState } from 'react';
import {
Modal,
Transfer,
Input,
Space,
Checkbox,
Avatar,
Highlight,
} from '@douyinfe/semi-ui';
import { IconClose } from '@douyinfe/semi-icons';
const CHANNEL_STATUS_CONFIG = {
1: { color: 'green', text: '启用' },
2: { color: 'red', text: '禁用' },
3: { color: 'amber', text: '自禁' },
default: { color: 'grey', text: '未知' }
};
const getChannelStatusConfig = (status) => {
return CHANNEL_STATUS_CONFIG[status] || CHANNEL_STATUS_CONFIG.default;
};
export default function ChannelSelectorModal({
t,
visible,
onCancel,
onOk,
allChannels = [],
selectedChannelIds = [],
setSelectedChannelIds,
channelEndpoints,
updateChannelEndpoint,
}) {
const [searchText, setSearchText] = useState('');
const ChannelInfo = ({ item, showEndpoint = false, isSelected = false }) => {
const channelId = item.key || item.value;
const currentEndpoint = channelEndpoints[channelId];
const baseUrl = item._originalData?.base_url || '';
const status = item._originalData?.status || 0;
const statusConfig = getChannelStatusConfig(status);
return (
<>
<Avatar color={statusConfig.color} size="small">
{statusConfig.text}
</Avatar>
<div className="info">
<div className="name">
{isSelected ? (
item.label
) : (
<Highlight sourceString={item.label} searchWords={[searchText]} />
)}
</div>
<div className="email" style={showEndpoint ? { display: 'flex', alignItems: 'center', gap: '4px' } : {}}>
<span className="text-xs text-gray-500 truncate max-w-[200px]" title={baseUrl}>
{isSelected ? (
baseUrl
) : (
<Highlight sourceString={baseUrl} searchWords={[searchText]} />
)}
</span>
{showEndpoint && (
<Input
size="small"
value={currentEndpoint}
onChange={(value) => updateChannelEndpoint(channelId, value)}
placeholder="/api/ratio_config"
className="flex-1 text-xs"
style={{ fontSize: '12px' }}
/>
)}
{isSelected && !showEndpoint && (
<span className="text-xs text-gray-700 font-mono bg-gray-100 px-2 py-1 rounded ml-2">
{currentEndpoint}
</span>
)}
</div>
</div>
</>
);
};
const renderSourceItem = (item) => {
return (
<div className="components-transfer-source-item" key={item.key}>
<Checkbox
onChange={item.onChange}
checked={item.checked}
style={{ height: 52, alignItems: 'center' }}
>
<ChannelInfo item={item} showEndpoint={true} />
</Checkbox>
</div>
);
};
const renderSelectedItem = (item) => {
return (
<div className="components-transfer-selected-item" key={item.key}>
<ChannelInfo item={item} isSelected={true} />
<IconClose style={{ cursor: 'pointer' }} onClick={item.onRemove} />
</div>
);
};
const channelFilter = (input, item) => {
const searchLower = input.toLowerCase();
return item.label.toLowerCase().includes(searchLower) ||
(item._originalData?.base_url || '').toLowerCase().includes(searchLower);
};
return (
<Modal
visible={visible}
onCancel={onCancel}
onOk={onOk}
title={<span className="text-lg font-semibold">{t('选择同步渠道')}</span>}
width={1000}
>
<Space vertical style={{ width: '100%' }}>
<Transfer
style={{ width: '100%' }}
dataSource={allChannels}
value={selectedChannelIds}
onChange={setSelectedChannelIds}
renderSourceItem={renderSourceItem}
renderSelectedItem={renderSelectedItem}
filter={channelFilter}
inputProps={{ placeholder: t('搜索渠道名称或地址') }}
onSearch={setSearchText}
emptyContent={{
left: t('暂无渠道'),
right: t('暂无选择'),
search: t('无搜索结果'),
}}
/>
</Space>
</Modal>
);
}

View File

@@ -6,6 +6,7 @@ import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js'
import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js'; import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js';
import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js'; import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js';
import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js'; import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js';
import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js';
import { API, showError } from '../../helpers'; import { API, showError } from '../../helpers';
@@ -21,6 +22,7 @@ const RatioSetting = () => {
GroupGroupRatio: '', GroupGroupRatio: '',
AutoGroups: '', AutoGroups: '',
DefaultUseAutoGroup: false, DefaultUseAutoGroup: false,
ExposeRatioEnabled: false,
UserUsableGroups: '', UserUsableGroups: '',
}); });
@@ -48,7 +50,7 @@ const RatioSetting = () => {
// 如果后端返回的不是合法 JSON直接展示 // 如果后端返回的不是合法 JSON直接展示
} }
} }
if (['DefaultUseAutoGroup'].includes(item.key)) { if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) {
newInputs[item.key] = item.value === 'true' ? true : false; newInputs[item.key] = item.value === 'true' ? true : false;
} else { } else {
newInputs[item.key] = item.value; newInputs[item.key] = item.value;
@@ -78,10 +80,6 @@ const RatioSetting = () => {
return ( return (
<Spin spinning={loading} size='large'> <Spin spinning={loading} size='large'>
{/* 分组倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<GroupRatioSettings options={inputs} refresh={onRefresh} />
</Card>
{/* 模型倍率设置以及可视化编辑器 */} {/* 模型倍率设置以及可视化编辑器 */}
<Card style={{ marginTop: '10px' }}> <Card style={{ marginTop: '10px' }}>
<Tabs type='line'> <Tabs type='line'>
@@ -100,8 +98,18 @@ const RatioSetting = () => {
refresh={onRefresh} refresh={onRefresh}
/> />
</Tabs.TabPane> </Tabs.TabPane>
<Tabs.TabPane tab={t('上游倍率同步')} itemKey='upstream_sync'>
<UpstreamRatioSync
options={inputs}
refresh={onRefresh}
/>
</Tabs.TabPane>
</Tabs> </Tabs>
</Card> </Card>
{/* 分组倍率设置 */}
<Card style={{ marginTop: '10px' }}>
<GroupRatioSettings options={inputs} refresh={onRefresh} />
</Card>
</Spin> </Spin>
); );
}; };

View File

@@ -11,7 +11,9 @@ import {
XCircle, XCircle,
Loader, Loader,
List, List,
Hash Hash,
Video,
Sparkles
} from 'lucide-react'; } from 'lucide-react';
import { import {
API, API,
@@ -80,6 +82,7 @@ const COLUMN_KEYS = {
TASK_STATUS: 'task_status', TASK_STATUS: 'task_status',
PROGRESS: 'progress', PROGRESS: 'progress',
FAIL_REASON: 'fail_reason', FAIL_REASON: 'fail_reason',
RESULT_URL: 'result_url',
}; };
const renderTimestamp = (timestampInSeconds) => { const renderTimestamp = (timestampInSeconds) => {
@@ -150,6 +153,7 @@ const LogsTable = () => {
[COLUMN_KEYS.TASK_STATUS]: true, [COLUMN_KEYS.TASK_STATUS]: true,
[COLUMN_KEYS.PROGRESS]: true, [COLUMN_KEYS.PROGRESS]: true,
[COLUMN_KEYS.FAIL_REASON]: true, [COLUMN_KEYS.FAIL_REASON]: true,
[COLUMN_KEYS.RESULT_URL]: true,
}; };
}; };
@@ -203,6 +207,12 @@ const LogsTable = () => {
{t('生成歌词')} {t('生成歌词')}
</Tag> </Tag>
); );
case 'generate':
return (
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
{t('生成视频')}
</Tag>
);
default: default:
return ( return (
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}> <Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
@@ -220,6 +230,12 @@ const LogsTable = () => {
Suno Suno
</Tag> </Tag>
); );
case 'kling':
return (
<Tag color='blue' size='large' shape='circle' prefixIcon={<Video size={14} />}>
Kling
</Tag>
);
default: default:
return ( return (
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}> <Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
@@ -411,10 +427,21 @@ const LogsTable = () => {
}, },
{ {
key: COLUMN_KEYS.FAIL_REASON, key: COLUMN_KEYS.FAIL_REASON,
title: t('失败原因'), title: t('详情'),
dataIndex: 'fail_reason', dataIndex: 'fail_reason',
fixed: 'right', fixed: 'right',
render: (text, record, index) => { render: (text, record, index) => {
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
const isVideoTask = record.action === 'generate';
const isSuccess = record.status === 'SUCCESS';
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
if (isSuccess && isVideoTask && isUrl) {
return (
<a href={text} target="_blank" rel="noopener noreferrer">
{t('点击预览视频')}
</a>
);
}
if (!text) { if (!text) {
return t('无'); return t('无');
} }

View File

@@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [
color: 'blue', color: 'blue',
label: 'Coze', label: 'Coze',
}, },
{
value: 50,
color: 'green',
label: '可灵',
},
]; ];

View File

@@ -1 +1,3 @@
export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend!
export const DEFAULT_ENDPOINT = '/api/ratio_config';

View File

@@ -1665,5 +1665,28 @@
"确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?", "确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?",
"将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.", "将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.",
"选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)", "选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)",
"请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)" "请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)",
"上游倍率同步": "Upstream ratio synchronization",
"获取渠道失败:": "Failed to get channels: ",
"请至少选择一个渠道": "Please select at least one channel",
"获取倍率失败:": "Failed to get ratios: ",
"后端请求失败": "Backend request failed",
"部分渠道测试失败:": "Some channels failed to test: ",
"已与上游倍率完全一致,无需同步": "The upstream ratio is completely consistent, no synchronization is required",
"请求后端接口失败:": "Failed to request the backend interface: ",
"同步成功": "Synchronization successful",
"部分保存失败": "Some settings failed to save",
"保存失败": "Save failed",
"选择同步渠道": "Select synchronization channel",
"应用同步": "Apply synchronization",
"倍率类型": "Ratio type",
"当前值": "Current value",
"上游值": "Upstream value",
"差异": "Difference",
"搜索渠道名称或地址": "Search channel name or address",
"缓存倍率": "Cache ratio",
"暂无差异化倍率显示": "No differential ratio display",
"请先选择同步渠道": "Please select the synchronization channel first",
"与本地相同": "Same as local",
"未找到匹配的模型": "No matching model found"
} }

View File

@@ -432,4 +432,72 @@ code {
.semi-table-tbody>.semi-table-row { .semi-table-tbody>.semi-table-row {
border-bottom: 1px solid rgba(0, 0, 0, 0.1); border-bottom: 1px solid rgba(0, 0, 0, 0.1);
} }
}
/* ==================== 同步倍率 - 渠道选择器 ==================== */
.components-transfer-source-item,
.components-transfer-selected-item {
display: flex;
align-items: center;
padding: 8px;
}
.semi-transfer-left-list,
.semi-transfer-right-list {
-ms-overflow-style: none;
scrollbar-width: none;
}
.semi-transfer-left-list::-webkit-scrollbar,
.semi-transfer-right-list::-webkit-scrollbar {
display: none;
}
.components-transfer-source-item .semi-checkbox,
.components-transfer-selected-item .semi-checkbox {
display: flex;
align-items: center;
width: 100%;
}
.components-transfer-source-item .semi-avatar,
.components-transfer-selected-item .semi-avatar {
margin-right: 12px;
flex-shrink: 0;
}
.components-transfer-source-item .info,
.components-transfer-selected-item .info {
flex: 1;
overflow: hidden;
display: flex;
flex-direction: column;
justify-content: center;
}
.components-transfer-source-item .name,
.components-transfer-selected-item .name {
font-weight: 500;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.components-transfer-source-item .email,
.components-transfer-selected-item .email {
font-size: 12px;
color: var(--semi-color-text-2);
display: flex;
align-items: center;
}
.components-transfer-selected-item .semi-icon-close {
margin-left: 8px;
cursor: pointer;
color: var(--semi-color-text-2);
}
.components-transfer-selected-item .semi-icon-close:hover {
color: var(--semi-color-text-0);
} }

View File

@@ -1112,7 +1112,6 @@ const Detail = (props) => {
</div> </div>
<Tabs <Tabs
type="button" type="button"
preventScroll={true}
activeKey={activeChartTab} activeKey={activeChartTab}
onChange={setActiveChartTab} onChange={setActiveChartTab}
> >
@@ -1389,7 +1388,6 @@ const Detail = (props) => {
) : ( ) : (
<Tabs <Tabs
type="card" type="card"
preventScroll={true}
collapsible collapsible
activeKey={activeUptimeTab} activeKey={activeUptimeTab}
onChange={setActiveUptimeTab} onChange={setActiveUptimeTab}

View File

@@ -25,6 +25,7 @@ export default function ModelRatioSettings(props) {
ModelRatio: '', ModelRatio: '',
CacheRatio: '', CacheRatio: '',
CompletionRatio: '', CompletionRatio: '',
ExposeRatioEnabled: false,
}); });
const refForm = useRef(); const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs); const [inputsRow, setInputsRow] = useState(inputs);
@@ -206,6 +207,17 @@ export default function ModelRatioSettings(props) {
/> />
</Col> </Col>
</Row> </Row>
<Row gutter={16}>
<Col span={16}>
<Form.Switch
label={t('暴露倍率接口')}
field={'ExposeRatioEnabled'}
onChange={(value) =>
setInputs({ ...inputs, ExposeRatioEnabled: value })
}
/>
</Col>
</Row>
</Form.Section> </Form.Section>
</Form> </Form>
<Space> <Space>

View File

@@ -0,0 +1,503 @@
import React, { useState, useCallback, useMemo } from 'react';
import {
Button,
Table,
Tag,
Empty,
Checkbox,
Form,
Input,
} from '@douyinfe/semi-ui';
import { IconSearch } from '@douyinfe/semi-icons';
import {
RefreshCcw,
CheckSquare,
} from 'lucide-react';
import { API, showError, showSuccess, showWarning, stringToColor } from '../../../helpers';
import { DEFAULT_ENDPOINT } from '../../../constants';
import { useTranslation } from 'react-i18next';
import {
IllustrationNoResult,
IllustrationNoResultDark
} from '@douyinfe/semi-illustrations';
import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal';
export default function UpstreamRatioSync(props) {
const { t } = useTranslation();
const [modalVisible, setModalVisible] = useState(false);
const [loading, setLoading] = useState(false);
const [syncLoading, setSyncLoading] = useState(false);
// 渠道选择相关
const [allChannels, setAllChannels] = useState([]);
const [selectedChannelIds, setSelectedChannelIds] = useState([]);
// 渠道端点配置
const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint }
// 差异数据和测试结果
const [differences, setDifferences] = useState({});
const [resolutions, setResolutions] = useState({});
// 是否已经执行过同步
const [hasSynced, setHasSynced] = useState(false);
// 分页相关状态
const [currentPage, setCurrentPage] = useState(1);
const [pageSize, setPageSize] = useState(10);
// 搜索相关状态
const [searchKeyword, setSearchKeyword] = useState('');
const fetchAllChannels = async () => {
setLoading(true);
try {
const res = await API.get('/api/ratio_sync/channels');
if (res.data.success) {
const channels = res.data.data || [];
const transferData = channels.map(channel => ({
key: channel.id,
label: channel.name,
value: channel.id,
disabled: false,
_originalData: channel,
}));
setAllChannels(transferData);
const initialEndpoints = {};
transferData.forEach(channel => {
initialEndpoints[channel.key] = DEFAULT_ENDPOINT;
});
setChannelEndpoints(initialEndpoints);
} else {
showError(res.data.message);
}
} catch (error) {
showError(t('获取渠道失败:') + error.message);
} finally {
setLoading(false);
}
};
const confirmChannelSelection = () => {
const selected = allChannels
.filter(ch => selectedChannelIds.includes(ch.value))
.map(ch => ch._originalData);
if (selected.length === 0) {
showWarning(t('请至少选择一个渠道'));
return;
}
setModalVisible(false);
fetchRatiosFromChannels(selected);
};
const fetchRatiosFromChannels = async (channelList) => {
setSyncLoading(true);
const payload = {
channel_ids: channelList.map(ch => parseInt(ch.id)),
timeout: 10,
};
try {
const res = await API.post('/api/ratio_sync/fetch', payload);
if (!res.data.success) {
showError(res.data.message || t('后端请求失败'));
setSyncLoading(false);
return;
}
const { differences = {}, test_results = [] } = res.data.data;
const errorResults = test_results.filter(r => r.status === 'error');
if (errorResults.length > 0) {
showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', '));
}
setDifferences(differences);
setResolutions({});
setHasSynced(true);
if (Object.keys(differences).length === 0) {
showSuccess(t('已与上游倍率完全一致,无需同步'));
}
} catch (e) {
showError(t('请求后端接口失败:') + e.message);
} finally {
setSyncLoading(false);
}
};
const selectValue = (model, ratioType, value) => {
setResolutions(prev => ({
...prev,
[model]: {
...prev[model],
[ratioType]: value,
},
}));
};
const applySync = async () => {
const currentRatios = {
ModelRatio: JSON.parse(props.options.ModelRatio || '{}'),
CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'),
CacheRatio: JSON.parse(props.options.CacheRatio || '{}'),
ModelPrice: JSON.parse(props.options.ModelPrice || '{}'),
};
Object.entries(resolutions).forEach(([model, ratios]) => {
Object.entries(ratios).forEach(([ratioType, value]) => {
const optionKey = ratioType
.split('_')
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
.join('');
currentRatios[optionKey][model] = parseFloat(value);
});
});
setLoading(true);
try {
const updates = Object.entries(currentRatios).map(([key, value]) =>
API.put('/api/option/', {
key,
value: JSON.stringify(value, null, 2),
})
);
const results = await Promise.all(updates);
if (results.every(res => res.data.success)) {
showSuccess(t('同步成功'));
props.refresh();
setDifferences(prevDifferences => {
const newDifferences = { ...prevDifferences };
Object.entries(resolutions).forEach(([model, ratios]) => {
Object.keys(ratios).forEach(ratioType => {
if (newDifferences[model] && newDifferences[model][ratioType]) {
delete newDifferences[model][ratioType];
if (Object.keys(newDifferences[model]).length === 0) {
delete newDifferences[model];
}
}
});
});
return newDifferences;
});
setResolutions({});
} else {
showError(t('部分保存失败'));
}
} catch (error) {
showError(t('保存失败'));
} finally {
setLoading(false);
}
};
const getCurrentPageData = (dataSource) => {
const startIndex = (currentPage - 1) * pageSize;
const endIndex = startIndex + pageSize;
return dataSource.slice(startIndex, endIndex);
};
const renderHeader = () => (
<div className="flex flex-col w-full">
<div className="flex flex-col md:flex-row justify-between items-center gap-4 w-full">
<div className="flex gap-2 w-full md:w-auto order-2 md:order-1">
<Button
icon={<RefreshCcw size={14} />}
className="!rounded-full w-full md:w-auto mt-2"
onClick={() => {
setModalVisible(true);
fetchAllChannels();
}}
>
{t('选择同步渠道')}
</Button>
{(() => {
const hasSelections = Object.keys(resolutions).length > 0;
return (
<Button
icon={<CheckSquare size={14} />}
type='secondary'
onClick={applySync}
disabled={!hasSelections}
className="!rounded-full w-full md:w-auto mt-2"
>
{t('应用同步')}
</Button>
);
})()}
<Input
prefix={<IconSearch size={14} />}
placeholder={t('搜索模型名称')}
value={searchKeyword}
onChange={setSearchKeyword}
className="!rounded-full w-full md:w-64 mt-2"
showClear
/>
</div>
</div>
</div>
);
const renderDifferenceTable = () => {
const dataSource = useMemo(() => {
const tmp = [];
Object.entries(differences).forEach(([model, ratioTypes]) => {
Object.entries(ratioTypes).forEach(([ratioType, diff]) => {
tmp.push({
key: `${model}_${ratioType}`,
model,
ratioType,
current: diff.current,
upstreams: diff.upstreams,
});
});
});
return tmp;
}, [differences]);
const filteredDataSource = useMemo(() => {
if (!searchKeyword.trim()) {
return dataSource;
}
const keyword = searchKeyword.toLowerCase().trim();
return dataSource.filter(item =>
item.model.toLowerCase().includes(keyword)
);
}, [dataSource, searchKeyword]);
const upstreamNames = useMemo(() => {
const set = new Set();
filteredDataSource.forEach((row) => {
Object.keys(row.upstreams || {}).forEach((name) => set.add(name));
});
return Array.from(set);
}, [filteredDataSource]);
if (filteredDataSource.length === 0) {
return (
<Empty
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
darkModeImage={<IllustrationNoResultDark style={{ width: 150, height: 150 }} />}
description={
searchKeyword.trim()
? t('未找到匹配的模型')
: (Object.keys(differences).length === 0 ?
(hasSynced ? t('暂无差异化倍率显示') : t('请先选择同步渠道'))
: t('请先选择同步渠道'))
}
style={{ padding: 30 }}
/>
);
}
const columns = [
{
title: t('模型'),
dataIndex: 'model',
fixed: 'left',
},
{
title: t('倍率类型'),
dataIndex: 'ratioType',
render: (text) => {
const typeMap = {
model_ratio: t('模型倍率'),
completion_ratio: t('补全倍率'),
cache_ratio: t('缓存倍率'),
model_price: t('固定价格'),
};
return <Tag color={stringToColor(text)} shape="circle">{typeMap[text] || text}</Tag>;
},
},
{
title: t('当前值'),
dataIndex: 'current',
render: (text) => (
<Tag color={text !== null && text !== undefined ? 'blue' : 'default'} shape="circle">
{text !== null && text !== undefined ? text : t('未设置')}
</Tag>
),
},
...upstreamNames.map((upName) => {
const channelStats = (() => {
let selectableCount = 0;
let selectedCount = 0;
filteredDataSource.forEach((row) => {
const upstreamVal = row.upstreams?.[upName];
if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') {
selectableCount++;
const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal;
if (isSelected) {
selectedCount++;
}
}
});
return {
selectableCount,
selectedCount,
allSelected: selectableCount > 0 && selectedCount === selectableCount,
partiallySelected: selectedCount > 0 && selectedCount < selectableCount,
hasSelectableItems: selectableCount > 0
};
})();
const handleBulkSelect = (checked) => {
setResolutions((prev) => {
const newRes = { ...prev };
filteredDataSource.forEach((row) => {
const upstreamVal = row.upstreams?.[upName];
if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') {
if (checked) {
if (!newRes[row.model]) newRes[row.model] = {};
newRes[row.model][row.ratioType] = upstreamVal;
} else {
if (newRes[row.model]) {
delete newRes[row.model][row.ratioType];
if (Object.keys(newRes[row.model]).length === 0) {
delete newRes[row.model];
}
}
}
}
});
return newRes;
});
};
return {
title: channelStats.hasSelectableItems ? (
<Checkbox
checked={channelStats.allSelected}
indeterminate={channelStats.partiallySelected}
onChange={(e) => handleBulkSelect(e.target.checked)}
>
{upName}
</Checkbox>
) : (
<span>{upName}</span>
),
dataIndex: upName,
render: (_, record) => {
const upstreamVal = record.upstreams?.[upName];
if (upstreamVal === null || upstreamVal === undefined) {
return <Tag color="default" shape="circle">{t('未设置')}</Tag>;
}
if (upstreamVal === 'same') {
return <Tag color="blue" shape="circle">{t('与本地相同')}</Tag>;
}
const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal;
return (
<Checkbox
checked={isSelected}
onChange={(e) => {
const isChecked = e.target.checked;
if (isChecked) {
selectValue(record.model, record.ratioType, upstreamVal);
} else {
setResolutions((prev) => {
const newRes = { ...prev };
if (newRes[record.model]) {
delete newRes[record.model][record.ratioType];
if (Object.keys(newRes[record.model]).length === 0) {
delete newRes[record.model];
}
}
return newRes;
});
}
}}
>
{upstreamVal}
</Checkbox>
);
},
};
}),
];
return (
<Table
columns={columns}
dataSource={getCurrentPageData(filteredDataSource)}
pagination={{
currentPage: currentPage,
pageSize: pageSize,
total: filteredDataSource.length,
showSizeChanger: true,
showQuickJumper: true,
formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
start: page.currentStart,
end: page.currentEnd,
total: filteredDataSource.length,
}),
pageSizeOptions: ['5', '10', '20', '50'],
onChange: (page, size) => {
setCurrentPage(page);
setPageSize(size);
},
onShowSizeChange: (current, size) => {
setCurrentPage(1);
setPageSize(size);
}
}}
scroll={{ x: 'max-content' }}
size='middle'
loading={loading || syncLoading}
className="rounded-xl overflow-hidden"
/>
);
};
const updateChannelEndpoint = useCallback((channelId, endpoint) => {
setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint }));
}, []);
return (
<>
<Form.Section text={renderHeader()}>
{renderDifferenceTable()}
</Form.Section>
<ChannelSelectorModal
t={t}
visible={modalVisible}
onCancel={() => setModalVisible(false)}
onOk={confirmChannelSelection}
allChannels={allChannels}
selectedChannelIds={selectedChannelIds}
setSelectedChannelIds={setSelectedChannelIds}
channelEndpoints={channelEndpoints}
updateChannelEndpoint={updateChannelEndpoint}
/>
</>
);
}