Merge branch 'alpha'
This commit is contained in:
@@ -241,6 +241,7 @@ const (
|
|||||||
ChannelTypeXinference = 47
|
ChannelTypeXinference = 47
|
||||||
ChannelTypeXai = 48
|
ChannelTypeXai = 48
|
||||||
ChannelTypeCoze = 49
|
ChannelTypeCoze = 49
|
||||||
|
ChannelTypeKling = 50
|
||||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||||
|
|
||||||
)
|
)
|
||||||
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //47
|
"", //47
|
||||||
"https://api.x.ai", //48
|
"https://api.x.ai", //48
|
||||||
"https://api.coze.cn", //49
|
"https://api.coze.cn", //49
|
||||||
|
"https://api.klingai.com", //50
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
|||||||
}
|
}
|
||||||
return strconv.ParseFloat(durationStr, 64)
|
return strconv.ParseFloat(durationStr, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||||
|
func BuildURL(base string, endpoint string) string {
|
||||||
|
u, err := url.Parse(base)
|
||||||
|
if err != nil {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
end := endpoint
|
||||||
|
if end == "" {
|
||||||
|
end = "/"
|
||||||
|
}
|
||||||
|
ref, err := url.Parse(end)
|
||||||
|
if err != nil {
|
||||||
|
return base + endpoint
|
||||||
|
}
|
||||||
|
return u.ResolveReference(ref).String()
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ type TaskPlatform string
|
|||||||
const (
|
const (
|
||||||
TaskPlatformSuno TaskPlatform = "suno"
|
TaskPlatformSuno TaskPlatform = "suno"
|
||||||
TaskPlatformMidjourney = "mj"
|
TaskPlatformMidjourney = "mj"
|
||||||
|
TaskPlatformKling TaskPlatform = "kling"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
if channel.Type == common.ChannelTypeSunoAPI {
|
if channel.Type == common.ChannelTypeSunoAPI {
|
||||||
return errors.New("suno channel test is not supported"), nil
|
return errors.New("suno channel test is not supported"), nil
|
||||||
}
|
}
|
||||||
|
if channel.Type == common.ChannelTypeKling {
|
||||||
|
return errors.New("kling channel test is not supported"), nil
|
||||||
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(w)
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
|||||||
|
|
||||||
info := relaycommon.GenRelayInfo(c)
|
info := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, info)
|
err = helper.ModelMappedHelper(c, info, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/samber/lo"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@@ -136,6 +137,9 @@ func init() {
|
|||||||
adaptor.Init(meta)
|
adaptor.Init(meta)
|
||||||
channelId2Models[i] = adaptor.GetModelList()
|
channelId2Models[i] = adaptor.GetModelList()
|
||||||
}
|
}
|
||||||
|
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||||
|
return m.Id
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context) {
|
||||||
|
|||||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetRatioConfig(c *gin.Context) {
|
||||||
|
if !ratio_setting.IsExposeRatioEnabled() {
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "倍率配置接口未启用",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": ratio_setting.GetExposedData(),
|
||||||
|
})
|
||||||
|
}
|
||||||
322
controller/ratio_sync.go
Normal file
322
controller/ratio_sync.go
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTimeoutSeconds = 10
|
||||||
|
defaultEndpoint = "/api/ratio_config"
|
||||||
|
maxConcurrentFetches = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||||
|
|
||||||
|
type upstreamResult struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Data map[string]any `json:"data,omitempty"`
|
||||||
|
Err string `json:"err,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchUpstreamRatios(c *gin.Context) {
|
||||||
|
var req dto.UpstreamRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Timeout <= 0 {
|
||||||
|
req.Timeout = defaultTimeoutSeconds
|
||||||
|
}
|
||||||
|
|
||||||
|
var upstreams []dto.UpstreamDTO
|
||||||
|
|
||||||
|
if len(req.ChannelIDs) > 0 {
|
||||||
|
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||||
|
for _, id64 := range req.ChannelIDs {
|
||||||
|
intIds = append(intIds, int(id64))
|
||||||
|
}
|
||||||
|
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, ch := range dbChannels {
|
||||||
|
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||||
|
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||||
|
Name: ch.Name,
|
||||||
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
|
Endpoint: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
ch := make(chan upstreamResult, len(upstreams))
|
||||||
|
|
||||||
|
sem := make(chan struct{}, maxConcurrentFetches)
|
||||||
|
|
||||||
|
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||||
|
|
||||||
|
for _, chn := range upstreams {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(chItem dto.UpstreamDTO) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
sem <- struct{}{}
|
||||||
|
defer func() { <-sem }()
|
||||||
|
|
||||||
|
endpoint := chItem.Endpoint
|
||||||
|
if endpoint == "" {
|
||||||
|
endpoint = defaultEndpoint
|
||||||
|
} else if !strings.HasPrefix(endpoint, "/") {
|
||||||
|
endpoint = "/" + endpoint
|
||||||
|
}
|
||||||
|
fullURL := chItem.BaseURL + endpoint
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||||
|
ch <- upstreamResult{Name: chItem.Name, Err: resp.Status}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var body struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data map[string]any `json:"data"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||||
|
ch <- upstreamResult{Name: chItem.Name, Err: err.Error()}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !body.Success {
|
||||||
|
ch <- upstreamResult{Name: chItem.Name, Err: body.Message}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch <- upstreamResult{Name: chItem.Name, Data: body.Data}
|
||||||
|
}(chn)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(ch)
|
||||||
|
|
||||||
|
localData := ratio_setting.GetExposedData()
|
||||||
|
|
||||||
|
var testResults []dto.TestResult
|
||||||
|
var successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
for r := range ch {
|
||||||
|
if r.Err != "" {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "error",
|
||||||
|
Error: r.Err,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
testResults = append(testResults, dto.TestResult{
|
||||||
|
Name: r.Name,
|
||||||
|
Status: "success",
|
||||||
|
})
|
||||||
|
successfulChannels = append(successfulChannels, struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}{name: r.Name, data: r.Data})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
differences := buildDifferences(localData, successfulChannels)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"data": gin.H{
|
||||||
|
"differences": differences,
|
||||||
|
"test_results": testResults,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||||
|
name string
|
||||||
|
data map[string]any
|
||||||
|
}) map[string]map[string]dto.DifferenceItem {
|
||||||
|
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||||
|
|
||||||
|
allModels := make(map[string]struct{})
|
||||||
|
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
|
for modelName := range localRatio {
|
||||||
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
for modelName := range upstreamRatio {
|
||||||
|
allModels[modelName] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName := range allModels {
|
||||||
|
for _, ratioType := range ratioTypes {
|
||||||
|
var localValue interface{} = nil
|
||||||
|
if localRatioAny, ok := localData[ratioType]; ok {
|
||||||
|
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||||
|
if val, exists := localRatio[modelName]; exists {
|
||||||
|
localValue = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues := make(map[string]interface{})
|
||||||
|
hasUpstreamValue := false
|
||||||
|
hasDifference := false
|
||||||
|
|
||||||
|
for _, channel := range successfulChannels {
|
||||||
|
var upstreamValue interface{} = nil
|
||||||
|
|
||||||
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||||
|
if val, exists := upstreamRatio[modelName]; exists {
|
||||||
|
upstreamValue = val
|
||||||
|
hasUpstreamValue = true
|
||||||
|
|
||||||
|
if localValue != nil && localValue != val {
|
||||||
|
hasDifference = true
|
||||||
|
} else if localValue == val {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamValue == nil && localValue == nil {
|
||||||
|
upstreamValue = "same"
|
||||||
|
}
|
||||||
|
|
||||||
|
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||||
|
hasDifference = true
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamValues[channel.name] = upstreamValue
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldInclude := false
|
||||||
|
|
||||||
|
if localValue != nil {
|
||||||
|
if hasDifference {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if hasUpstreamValue {
|
||||||
|
shouldInclude = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldInclude {
|
||||||
|
if differences[modelName] == nil {
|
||||||
|
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||||
|
}
|
||||||
|
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||||
|
Current: localValue,
|
||||||
|
Upstreams: upstreamValues,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
channelHasDiff := make(map[string]bool)
|
||||||
|
for _, ratioMap := range differences {
|
||||||
|
for _, item := range ratioMap {
|
||||||
|
for chName, val := range item.Upstreams {
|
||||||
|
if val != nil && val != "same" {
|
||||||
|
channelHasDiff[chName] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelName, ratioMap := range differences {
|
||||||
|
for ratioType, item := range ratioMap {
|
||||||
|
for chName := range item.Upstreams {
|
||||||
|
if !channelHasDiff[chName] {
|
||||||
|
delete(item.Upstreams, chName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
differences[modelName][ratioType] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return differences
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSyncableChannels(c *gin.Context) {
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var syncableChannels []dto.SyncableChannel
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||||
|
ID: channel.Id,
|
||||||
|
Name: channel.Name,
|
||||||
|
BaseURL: channel.GetBaseURL(),
|
||||||
|
Status: channel.Status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": syncableChannels,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||||
var err *dto.TaskError
|
var err *dto.TaskError
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||||
err = relay.RelayTaskFetch(c, relayMode)
|
err = relay.RelayTaskFetch(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relay.RelayTaskSubmit(c, relayMode)
|
err = relay.RelayTaskSubmit(c, relayMode)
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
|||||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||||
case constant.TaskPlatformSuno:
|
case constant.TaskPlatformSuno:
|
||||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
|
case constant.TaskPlatformKling:
|
||||||
|
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
|
||||||
default:
|
default:
|
||||||
common.SysLog("未知平台")
|
common.SysLog("未知平台")
|
||||||
}
|
}
|
||||||
|
|||||||
140
controller/task_video.go
Normal file
140
controller/task_video.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/model"
|
||||||
|
"one-api/relay"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||||
|
for channelId, taskIds := range taskChannelM {
|
||||||
|
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||||
|
if len(taskIds) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||||
|
if err != nil {
|
||||||
|
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||||
|
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||||
|
"status": "FAILURE",
|
||||||
|
"progress": "100%",
|
||||||
|
})
|
||||||
|
if errUpdate != nil {
|
||||||
|
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||||
|
}
|
||||||
|
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
|
||||||
|
if adaptor == nil {
|
||||||
|
return fmt.Errorf("video adaptor not found")
|
||||||
|
}
|
||||||
|
for _, taskId := range taskIds {
|
||||||
|
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||||
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
|
if channel.GetBaseURL() != "" {
|
||||||
|
baseURL = channel.GetBaseURL()
|
||||||
|
}
|
||||||
|
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||||
|
"task_id": taskId,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseItem map[string]interface{}
|
||||||
|
err = json.Unmarshal(responseBody, &responseItem)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
|
||||||
|
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code, _ := responseItem["code"].(float64)
|
||||||
|
if code != 0 {
|
||||||
|
return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, ok := responseItem["data"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
|
||||||
|
return fmt.Errorf("video task data format error for task %s", taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
task := taskM[taskId]
|
||||||
|
if task == nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||||
|
return fmt.Errorf("task %s not found", taskId)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status, ok := data["task_status"].(string); ok {
|
||||||
|
switch status {
|
||||||
|
case "submitted", "queued":
|
||||||
|
task.Status = model.TaskStatusSubmitted
|
||||||
|
case "processing":
|
||||||
|
task.Status = model.TaskStatusInProgress
|
||||||
|
case "succeed":
|
||||||
|
task.Status = model.TaskStatusSuccess
|
||||||
|
task.Progress = "100%"
|
||||||
|
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
|
||||||
|
task.FailReason = url
|
||||||
|
} else {
|
||||||
|
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
|
||||||
|
}
|
||||||
|
case "failed":
|
||||||
|
task.Status = model.TaskStatusFailure
|
||||||
|
task.Progress = "100%"
|
||||||
|
if reason, ok := data["fail_reason"].(string); ok {
|
||||||
|
task.FailReason = reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If task failed, refund quota
|
||||||
|
if task.Status == model.TaskStatusFailure {
|
||||||
|
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||||
|
quota := task.Quota
|
||||||
|
if quota != 0 {
|
||||||
|
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||||
|
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
task.Data = responseBody
|
||||||
|
if err := task.Update(); err != nil {
|
||||||
|
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payType := "wxpay"
|
|
||||||
if req.PaymentMethod == "zfb" {
|
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||||
payType = "alipay"
|
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||||
}
|
return
|
||||||
if req.PaymentMethod == "wx" {
|
|
||||||
req.PaymentMethod = "wxpay"
|
|
||||||
payType = "wxpay"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
callBackAddress := service.GetCallbackAddress()
|
callBackAddress := service.GetCallbackAddress()
|
||||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||||
Type: payType,
|
Type: req.PaymentMethod,
|
||||||
ServiceTradeNo: tradeNo,
|
ServiceTradeNo: tradeNo,
|
||||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||||
|
|||||||
49
dto/ratio_sync.go
Normal file
49
dto/ratio_sync.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
// UpstreamDTO 提交到后端同步倍率的上游渠道信息
|
||||||
|
// Endpoint 可以为空,后端会默认使用 /api/ratio_config
|
||||||
|
// BaseURL 必须以 http/https 开头,不要以 / 结尾
|
||||||
|
// 例如: https://api.example.com
|
||||||
|
// Endpoint: /api/ratio_config
|
||||||
|
// 提交示例:
|
||||||
|
// {
|
||||||
|
// "name": "openai",
|
||||||
|
// "base_url": "https://api.openai.com",
|
||||||
|
// "endpoint": "/ratio_config"
|
||||||
|
// }
|
||||||
|
|
||||||
|
type UpstreamDTO struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
BaseURL string `json:"base_url" binding:"required"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpstreamRequest struct {
|
||||||
|
ChannelIDs []int64 `json:"channel_ids"`
|
||||||
|
Timeout int `json:"timeout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestResult 上游测试连通性结果
|
||||||
|
type TestResult struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DifferenceItem 差异项
|
||||||
|
// Current 为本地值,可能为 nil
|
||||||
|
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||||
|
|
||||||
|
type DifferenceItem struct {
|
||||||
|
Current interface{} `json:"current"`
|
||||||
|
Upstreams map[string]interface{} `json:"upstreams"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncableChannel 可同步的渠道信息(base_url 不为空)
|
||||||
|
|
||||||
|
type SyncableChannel struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
}
|
||||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type VideoRequest struct {
|
||||||
|
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
|
||||||
|
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
|
||||||
|
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
|
||||||
|
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
|
||||||
|
Width int `json:"width" example:"512"` // Video width
|
||||||
|
Height int `json:"height" example:"512"` // Video height
|
||||||
|
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
|
||||||
|
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
|
||||||
|
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
|
||||||
|
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
|
||||||
|
User string `json:"user,omitempty" example:"user-1234"` // User identifier
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoResponse 视频生成提交任务后的响应
|
||||||
|
type VideoResponse struct {
|
||||||
|
TaskId string `json:"task_id"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskResponse 查询视频生成任务状态的响应
|
||||||
|
type VideoTaskResponse struct {
|
||||||
|
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
|
||||||
|
Status string `json:"status" example:"succeeded"` // 任务状态
|
||||||
|
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
|
||||||
|
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
|
||||||
|
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
|
||||||
|
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskMetadata 视频任务元数据
|
||||||
|
type VideoTaskMetadata struct {
|
||||||
|
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
|
||||||
|
Fps int `json:"fps" example:"30"` // 实际帧率
|
||||||
|
Width int `json:"width" example:"512"` // 实际宽度
|
||||||
|
Height int `json:"height" example:"512"` // 实际高度
|
||||||
|
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
|
||||||
|
}
|
||||||
|
|
||||||
|
// VideoTaskError 视频任务错误信息
|
||||||
|
type VideoTaskError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
@@ -170,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
|
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||||
|
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||||
|
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||||
|
shouldSelectChannel = false
|
||||||
|
} else {
|
||||||
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
}
|
||||||
|
c.Set("platform", string(constant.TaskPlatformKling))
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
relayMode := relayconstant.RelayModeGemini
|
relayMode := relayconstant.RelayModeGemini
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||||
|
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||||
|
|
||||||
// 自动添加所有注册的模型配置
|
// 自动添加所有注册的模型配置
|
||||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||||
@@ -267,6 +268,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||||
case "DefaultUseAutoGroup":
|
case "DefaultUseAutoGroup":
|
||||||
setting.DefaultUseAutoGroup = boolValue
|
setting.DefaultUseAutoGroup = boolValue
|
||||||
|
case "ExposeRatioEnabled":
|
||||||
|
ratio_setting.SetExposeRatioEnabled(boolValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch key {
|
switch key {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
preConsumedTokens = promptTokens
|
preConsumedTokens = promptTokens
|
||||||
relayInfo.PromptTokens = promptTokens
|
relayInfo.PromptTokens = promptTokens
|
||||||
}
|
}
|
||||||
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
audioRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
|
|||||||
|
|
||||||
// FetchTask
|
// FetchTask
|
||||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||||
|
|
||||||
|
ParseResultUrl(resp map[string]any) (string, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if claudeInfo.Usage.PromptTokens == 0 {
|
if claudeInfo.Usage.PromptTokens == 0 {
|
||||||
//上游出错
|
//上游出错
|
||||||
@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
|||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||||
}
|
}
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -618,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||||
err := helper.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
for _, choice := range response.Choices {
|
for _, choice := range response.Choices {
|
||||||
responseText += choice.Message.StringContent()
|
responseText += choice.Message.StringContent()
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
response.Usage = *usage
|
response.Usage = *usage
|
||||||
response.Id = helper.GetResponseID(c)
|
response.Id = helper.GetResponseID(c)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
|||||||
|
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage.PromptTokens = info.PromptTokens
|
||||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
|
|
||||||
var currentEvent string
|
var currentEvent string
|
||||||
var currentData string
|
var currentData string
|
||||||
var usage dto.Usage
|
var usage = &dto.Usage{}
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
if line == "" {
|
if line == "" {
|
||||||
if currentEvent != "" && currentData != "" {
|
if currentEvent != "" && currentData != "" {
|
||||||
// handle last event
|
// handle last event
|
||||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||||
currentEvent = ""
|
currentEvent = ""
|
||||||
currentData = ""
|
currentData = ""
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
|
|
||||||
// Last event
|
// Last event
|
||||||
if currentEvent != "" && currentData != "" {
|
if currentEvent != "" && currentData != "" {
|
||||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
if usage.TotalTokens == 0 {
|
if usage.TotalTokens == 0 {
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, &usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||||
|
|||||||
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
common.SysError("close_response_body_failed: " + err.Error())
|
|
||||||
}
|
|
||||||
if usage.TotalTokens == 0 {
|
if usage.TotalTokens == 0 {
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
usage.CompletionTokens += nodeToken
|
usage.CompletionTokens += nodeToken
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
|||||||
@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
|
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||||
info.UpstreamModelName = parts[0]
|
info.UpstreamModelName = parts[0]
|
||||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
|
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -75,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
responseText := strings.Builder{}
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||||
@@ -89,6 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||||
imageCount++
|
imageCount++
|
||||||
}
|
}
|
||||||
|
if part.Text != "" {
|
||||||
|
responseText.WriteString(part.Text)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 直接发送 GeminiChatResponse 响应
|
// 直接发送 GeminiChatResponse 响应
|
||||||
err = helper.ObjectData(c, geminiResponse)
|
err = helper.StringData(c, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, err.Error())
|
common.LogError(c, err.Error())
|
||||||
}
|
}
|
||||||
@@ -122,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算最终使用量
|
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||||
// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
if usage.CompletionTokens == 0 {
|
||||||
|
str := responseText.String()
|
||||||
|
if len(str) > 0 {
|
||||||
|
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
// 空补全,不需要使用量
|
||||||
|
usage = &dto.Usage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||||
//helper.Done(c)
|
//helper.Done(c)
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
|||||||
}
|
}
|
||||||
|
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
modelName := info.OriginModelName
|
modelName := info.UpstreamModelName
|
||||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -16,6 +15,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||||
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
forceFormat := false
|
forceFormat := false
|
||||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||||
forceFormat = forceFmt
|
forceFormat = forceFmt
|
||||||
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range simpleResponse.Choices {
|
for _, choice := range simpleResponse.Choices {
|
||||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||||
completionTokens += ctkm
|
completionTokens += ctkm
|
||||||
}
|
}
|
||||||
simpleResponse.Usage = dto.Usage{
|
simpleResponse.Usage = dto.Usage{
|
||||||
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
// the status code has been judged before, if there is a body reading failure,
|
// the status code has been judged before, if there is a body reading failure,
|
||||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||||
// and can be terminated directly.
|
// and can be terminated directly.
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
|||||||
if err = c.ShouldBind(&reqBody); err != nil {
|
if err = c.ShouldBind(&reqBody); err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||||
reqFp, err := reqBody.File.Open()
|
reqFp, err := reqBody.File.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
defer reqFp.Close()
|
defer reqFp.Close()
|
||||||
|
|
||||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
|||||||
tempStr := responseTextBuilder.String()
|
tempStr := responseTextBuilder.String()
|
||||||
if len(tempStr) > 0 {
|
if len(tempStr) > 0 {
|
||||||
// 非正常结束,使用输出文本的 token 数量
|
// 非正常结束,使用输出文本的 token 数量
|
||||||
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
|
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||||
usage.CompletionTokens = completionTokens
|
usage.CompletionTokens = completionTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = palmStreamHandler(c, resp)
|
err, responseText = palmStreamHandler(c, resp)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
312
relay/channel/task/kling/adaptor.go
Normal file
312
relay/channel/task/kling/adaptor.go
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
package kling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Request / Response structures
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type SubmitReq struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
Duration int `json:"duration,omitempty"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestPayload struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
Image string `json:"image,omitempty"`
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
Duration string `json:"duration,omitempty"`
|
||||||
|
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||||
|
Model string `json:"model,omitempty"`
|
||||||
|
ModelName string `json:"model_name,omitempty"`
|
||||||
|
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type responsePayload struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Data struct {
|
||||||
|
TaskID string `json:"task_id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Adaptor implementation
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
type TaskAdaptor struct {
|
||||||
|
ChannelType int
|
||||||
|
accessKey string
|
||||||
|
secretKey string
|
||||||
|
baseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
|
a.ChannelType = info.ChannelType
|
||||||
|
a.baseURL = info.BaseUrl
|
||||||
|
|
||||||
|
// apiKey format: "access_key,secret_key"
|
||||||
|
keyParts := strings.Split(info.ApiKey, ",")
|
||||||
|
if len(keyParts) == 2 {
|
||||||
|
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||||
|
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||||
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||||
|
// Accept only POST /v1/video/generations as "generate" action.
|
||||||
|
action := "generate"
|
||||||
|
info.Action = action
|
||||||
|
|
||||||
|
var req SubmitReq
|
||||||
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
|
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store into context for later usage
|
||||||
|
c.Set("kling_request", req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestURL constructs the upstream URL.
|
||||||
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||||
|
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestHeader sets required headers.
|
||||||
|
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||||
|
token, err := a.createJWTToken()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildRequestBody converts request into Kling specific format.
|
||||||
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||||
|
v, exists := c.Get("kling_request")
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("request not found in context")
|
||||||
|
}
|
||||||
|
req := v.(SubmitReq)
|
||||||
|
|
||||||
|
body := a.convertToRequestPayload(&req)
|
||||||
|
data, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return bytes.NewReader(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest delegates to common helper.
|
||||||
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoResponse handles upstream response, returns taskID etc.
|
||||||
|
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt Kling response parse first.
|
||||||
|
var kResp responsePayload
|
||||||
|
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
|
||||||
|
return kResp.Data.TaskID, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback generic task response.
|
||||||
|
var generic dto.TaskResponse[string]
|
||||||
|
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||||
|
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !generic.IsSuccess() {
|
||||||
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||||
|
return generic.Data, responseBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchTask fetch task status
|
||||||
|
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||||
|
taskID, ok := body["task_id"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid task_id")
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := a.createJWTTokenWithKey(key)
|
||||||
|
if err != nil {
|
||||||
|
token = key
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||||
|
|
||||||
|
return service.GetHttpClient().Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetModelList() []string {
|
||||||
|
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) GetChannelName() string {
|
||||||
|
return "kling"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
|
||||||
|
r := &requestPayload{
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
Image: req.Image,
|
||||||
|
Mode: defaultString(req.Mode, "std"),
|
||||||
|
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||||
|
AspectRatio: a.getAspectRatio(req.Size),
|
||||||
|
Model: req.Model,
|
||||||
|
ModelName: req.Model,
|
||||||
|
CfgScale: 0.5,
|
||||||
|
}
|
||||||
|
if r.Model == "" {
|
||||||
|
r.Model = "kling-v1"
|
||||||
|
r.ModelName = "kling-v1"
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||||
|
switch size {
|
||||||
|
case "1024x1024", "512x512":
|
||||||
|
return "1:1"
|
||||||
|
case "1280x720", "1920x1080":
|
||||||
|
return "16:9"
|
||||||
|
case "720x1280", "1080x1920":
|
||||||
|
return "9:16"
|
||||||
|
default:
|
||||||
|
return "1:1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultString(s, def string) string {
|
||||||
|
if strings.TrimSpace(s) == "" {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultInt(v int, def int) int {
|
||||||
|
if v == 0 {
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// JWT helpers
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||||
|
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||||
|
parts := strings.Split(apiKey, ",")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||||
|
}
|
||||||
|
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||||
|
if accessKey == "" || secretKey == "" {
|
||||||
|
return "", fmt.Errorf("access key and secret key are required")
|
||||||
|
}
|
||||||
|
now := time.Now().Unix()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": accessKey,
|
||||||
|
"exp": now + 1800, // 30 minutes
|
||||||
|
"nbf": now - 5,
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||||
|
token.Header["typ"] = "JWT"
|
||||||
|
return token.SignedString([]byte(secretKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseResultUrl 提取视频任务结果的 url
|
||||||
|
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||||
|
data, ok := resp["data"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("data field not found or invalid")
|
||||||
|
}
|
||||||
|
taskResult, ok := data["task_result"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("task_result field not found or invalid")
|
||||||
|
}
|
||||||
|
videos, ok := taskResult["videos"].([]interface{})
|
||||||
|
if !ok || len(videos) == 0 {
|
||||||
|
return "", fmt.Errorf("videos field not found or empty")
|
||||||
|
}
|
||||||
|
video, ok := videos[0].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("video item invalid")
|
||||||
|
}
|
||||||
|
url, ok := video["url"].(string)
|
||||||
|
if !ok || url == "" {
|
||||||
|
return "", fmt.Errorf("url field not found or invalid")
|
||||||
|
}
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
|
|||||||
ChannelType int
|
ChannelType int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||||
|
return "", nil // todo implement this method if needed
|
||||||
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||||
a.ChannelType = info.ChannelType
|
a.ChannelType = info.ChannelType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = tencentStreamHandler(c, resp)
|
err, responseText = tencentStreamHandler(c, resp)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = tencentHandler(c, resp)
|
err, usage = tencentHandler(c, resp)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
suffix := ""
|
suffix := ""
|
||||||
if a.RequestMode == RequestModeGemini {
|
if a.RequestMode == RequestModeGemini {
|
||||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||||
// suffix -thinking and -nothinking
|
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||||
|
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||||
|
info.UpstreamModelName = parts[0]
|
||||||
|
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
})
|
})
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
|||||||
relayInfo.IsStream = true
|
relayInfo.IsStream = true
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
textRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||||
// count messages token error 计算promptTokens错误
|
// count messages token error 计算promptTokens错误
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
|||||||
var httpResp *http.Response
|
var httpResp *http.Response
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
|
|||||||
@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RelayFormatOpenAI = "openai"
|
RelayFormatOpenAI = "openai"
|
||||||
RelayFormatClaude = "claude"
|
RelayFormatClaude = "claude"
|
||||||
RelayFormatGemini = "gemini"
|
RelayFormatGemini = "gemini"
|
||||||
|
RelayFormatOpenAIResponses = "openai_responses"
|
||||||
|
RelayFormatOpenAIAudio = "openai_audio"
|
||||||
|
RelayFormatOpenAIImage = "openai_image"
|
||||||
|
RelayFormatRerank = "rerank"
|
||||||
|
RelayFormatEmbedding = "embedding"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RerankerInfo struct {
|
type RerankerInfo struct {
|
||||||
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
|||||||
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||||
info := GenRelayInfo(c)
|
info := GenRelayInfo(c)
|
||||||
info.RelayMode = relayconstant.RelayModeRerank
|
info.RelayMode = relayconstant.RelayModeRerank
|
||||||
|
info.RelayFormat = RelayFormatRerank
|
||||||
info.RerankerInfo = &RerankerInfo{
|
info.RerankerInfo = &RerankerInfo{
|
||||||
Documents: req.Documents,
|
Documents: req.Documents,
|
||||||
ReturnDocuments: req.GetReturnDocuments(),
|
ReturnDocuments: req.GetReturnDocuments(),
|
||||||
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
|
||||||
|
info := GenRelayInfo(c)
|
||||||
|
info.RelayFormat = RelayFormatOpenAIAudio
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
|
||||||
|
info := GenRelayInfo(c)
|
||||||
|
info.RelayFormat = RelayFormatEmbedding
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
|
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||||
info := GenRelayInfo(c)
|
info := GenRelayInfo(c)
|
||||||
info.RelayMode = relayconstant.RelayModeResponses
|
info.RelayMode = relayconstant.RelayModeResponses
|
||||||
|
info.RelayFormat = RelayFormatOpenAIResponses
|
||||||
|
|
||||||
|
info.SupportStreamOptions = false
|
||||||
|
|
||||||
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
||||||
BuiltInTools: make(map[string]*BuildInToolInfo),
|
BuiltInTools: make(map[string]*BuildInToolInfo),
|
||||||
}
|
}
|
||||||
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
|
|||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
|
||||||
|
info := GenRelayInfo(c)
|
||||||
|
info.RelayFormat = RelayFormatGemini
|
||||||
|
info.ShouldIncludeUsage = false
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||||
|
info := GenRelayInfo(c)
|
||||||
|
info.RelayFormat = RelayFormatOpenAIImage
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := c.GetInt("channel_type")
|
channelType := c.GetInt("channel_type")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
@@ -243,10 +278,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
if streamSupportedChannels[info.ChannelType] {
|
if streamSupportedChannels[info.ChannelType] {
|
||||||
info.SupportStreamOptions = true
|
info.SupportStreamOptions = true
|
||||||
}
|
}
|
||||||
// responses 模式不支持 StreamOptions
|
|
||||||
if relayconstant.RelayModeResponses == info.RelayMode {
|
|
||||||
info.SupportStreamOptions = false
|
|
||||||
}
|
|
||||||
return info
|
return info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ const (
|
|||||||
RelayModeSunoFetchByID
|
RelayModeSunoFetchByID
|
||||||
RelayModeSunoSubmit
|
RelayModeSunoSubmit
|
||||||
|
|
||||||
|
RelayModeKlingFetchByID
|
||||||
|
RelayModeKlingSubmit
|
||||||
|
|
||||||
RelayModeRerank
|
RelayModeRerank
|
||||||
|
|
||||||
RelayModeResponses
|
RelayModeResponses
|
||||||
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
|
|||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Path2RelayKling(method, path string) int {
|
||||||
|
relayMode := RelayModeUnknown
|
||||||
|
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||||
|
relayMode = RelayModeKlingSubmit
|
||||||
|
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||||
|
relayMode = RelayModeKlingFetchByID
|
||||||
|
}
|
||||||
|
return relayMode
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||||
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
|
|||||||
}
|
}
|
||||||
|
|
||||||
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
||||||
|
|
||||||
var embeddingRequest *dto.EmbeddingRequest
|
var embeddingRequest *dto.EmbeddingRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||||
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddingRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||||
relayInfo.PromptTokens = promptToken
|
relayInfo.PromptTokens = promptToken
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
|||||||
return sensitiveWords, err
|
return sensitiveWords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
|
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||||
// 计算输入 token 数量
|
// 计算输入 token 数量
|
||||||
var inputTexts []string
|
var inputTexts []string
|
||||||
for _, content := range req.Contents {
|
for _, content := range req.Contents {
|
||||||
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
|||||||
}
|
}
|
||||||
|
|
||||||
inputText := strings.Join(inputTexts, "\n")
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
|
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||||
info.PromptTokens = inputTokens
|
info.PromptTokens = inputTokens
|
||||||
return inputTokens, err
|
return inputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
@@ -83,7 +83,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||||
|
|
||||||
// 检查 Gemini 流式模式
|
// 检查 Gemini 流式模式
|
||||||
checkGeminiStreamMode(c, relayInfo)
|
checkGeminiStreamMode(c, relayInfo)
|
||||||
@@ -97,7 +97,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// model mapped 模型映射
|
// model mapped 模型映射
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
promptTokens := value.(int)
|
promptTokens := value.(int)
|
||||||
relayInfo.SetPromptTokens(promptTokens)
|
relayInfo.SetPromptTokens(promptTokens)
|
||||||
} else {
|
} else {
|
||||||
promptTokens, err := getGeminiInputTokens(req, relayInfo)
|
promptTokens := getGeminiInputTokens(req, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
@@ -162,7 +162,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||||
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
@@ -4,12 +4,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
common2 "one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/relay/common"
|
"one-api/relay/common"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||||
// map model name
|
// map model name
|
||||||
modelMapping := c.GetString("model_mapping")
|
modelMapping := c.GetString("model_mapping")
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
if modelMapping != "" && modelMapping != "{}" {
|
||||||
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
|||||||
info.UpstreamModelName = currentModel
|
info.UpstreamModelName = currentModel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if request != nil {
|
||||||
|
switch info.RelayFormat {
|
||||||
|
case common.RelayFormatGemini:
|
||||||
|
// Gemini 模型映射
|
||||||
|
case common.RelayFormatClaude:
|
||||||
|
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||||
|
claudeRequest.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
case common.RelayFormatOpenAIResponses:
|
||||||
|
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||||
|
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
case common.RelayFormatOpenAIAudio:
|
||||||
|
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||||
|
openAIAudioRequest.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
case common.RelayFormatOpenAIImage:
|
||||||
|
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||||
|
imageRequest.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
case common.RelayFormatRerank:
|
||||||
|
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||||
|
rerankRequest.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
case common.RelayFormatEmbedding:
|
||||||
|
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||||
|
embeddingRequest.Model = info.UpstreamModelName
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||||
|
openAIRequest.Model = info.UpstreamModelName
|
||||||
|
} else {
|
||||||
|
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||||
relayInfo := relaycommon.GenRelayInfo(c)
|
relayInfo := relaycommon.GenRelayInfoImage(c)
|
||||||
|
|
||||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -110,13 +110,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
imageRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||||
@@ -108,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
textRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||||
var promptTokens int
|
var promptTokens int
|
||||||
if value, exists := c.Get("prompt_tokens"); exists {
|
if value, exists := c.Get("prompt_tokens"); exists {
|
||||||
@@ -253,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
|||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||||
case relayconstant.RelayModeCompletions:
|
case relayconstant.RelayModeCompletions:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||||
case relayconstant.RelayModeModerations:
|
case relayconstant.RelayModeModerations:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||||
default:
|
default:
|
||||||
err = errors.New("unknown relay mode")
|
err = errors.New("unknown relay mode")
|
||||||
promptTokens = 0
|
promptTokens = 0
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"one-api/relay/channel/palm"
|
"one-api/relay/channel/palm"
|
||||||
"one-api/relay/channel/perplexity"
|
"one-api/relay/channel/perplexity"
|
||||||
"one-api/relay/channel/siliconflow"
|
"one-api/relay/channel/siliconflow"
|
||||||
|
"one-api/relay/channel/task/kling"
|
||||||
"one-api/relay/channel/task/suno"
|
"one-api/relay/channel/task/suno"
|
||||||
"one-api/relay/channel/tencent"
|
"one-api/relay/channel/tencent"
|
||||||
"one-api/relay/channel/vertex"
|
"one-api/relay/channel/vertex"
|
||||||
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
|||||||
// return &aiproxy.Adaptor{}
|
// return &aiproxy.Adaptor{}
|
||||||
case commonconstant.TaskPlatformSuno:
|
case commonconstant.TaskPlatformSuno:
|
||||||
return &suno.TaskAdaptor{}
|
return &suno.TaskAdaptor{}
|
||||||
|
case commonconstant.TaskPlatformKling:
|
||||||
|
return &kling.TaskAdaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||||
|
if platform == constant.TaskPlatformKling {
|
||||||
|
modelName = relayInfo.OriginModelName
|
||||||
|
}
|
||||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||||
if !success {
|
if !success {
|
||||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||||
@@ -136,10 +139,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
relayInfo.ConsumeQuota = true
|
relayInfo.ConsumeQuota = true
|
||||||
// insert task
|
// insert task
|
||||||
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
|
task := model.InitTask(platform, relayInfo)
|
||||||
task.TaskID = taskID
|
task.TaskID = taskID
|
||||||
task.Quota = quota
|
task.Quota = quota
|
||||||
task.Data = taskData
|
task.Data = taskData
|
||||||
|
task.Action = relayInfo.Action
|
||||||
err = task.Insert()
|
err = task.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||||
@@ -149,8 +153,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||||
|
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||||
@@ -225,6 +230,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||||
|
taskId := c.Param("id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
|
||||||
|
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||||
|
if err != nil {
|
||||||
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exist {
|
||||||
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||||
|
Code: "success",
|
||||||
|
Data: TaskModel2Dto(originTask),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||||
return &dto.TaskDto{
|
return &dto.TaskDto{
|
||||||
TaskID: task.TaskID,
|
TaskID: task.TaskID,
|
||||||
|
|||||||
@@ -14,12 +14,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||||
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||||
for _, document := range rerankRequest.Documents {
|
for _, document := range rerankRequest.Documents {
|
||||||
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
|
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||||
if err == nil {
|
token += tkm
|
||||||
token += tkm
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
|||||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
rerankRequest.Model = relayInfo.UpstreamModelName
|
|
||||||
|
|
||||||
promptToken := getRerankPromptToken(*rerankRequest)
|
promptToken := getRerankPromptToken(*rerankRequest)
|
||||||
relayInfo.PromptTokens = promptToken
|
relayInfo.PromptTokens = promptToken
|
||||||
|
|
||||||
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
|
|||||||
return sensitiveWords, err
|
return sensitiveWords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
|
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
|
||||||
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
|
inputTokens := service.CountTokenInput(req.Input, req.Model)
|
||||||
info.PromptTokens = inputTokens
|
info.PromptTokens = inputTokens
|
||||||
return inputTokens, err
|
return inputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = helper.ModelMappedHelper(c, relayInfo)
|
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
req.Model = relayInfo.UpstreamModelName
|
|
||||||
if value, exists := c.Get("prompt_tokens"); exists {
|
if value, exists := c.Get("prompt_tokens"); exists {
|
||||||
promptTokens := value.(int)
|
promptTokens := value.(int)
|
||||||
relayInfo.SetPromptTokens(promptTokens)
|
relayInfo.SetPromptTokens(promptTokens)
|
||||||
} else {
|
} else {
|
||||||
promptTokens, err := getInputTokens(req, relayInfo)
|
promptTokens := getInputTokens(req, relayInfo)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
c.Set("prompt_tokens", promptTokens)
|
c.Set("prompt_tokens", promptTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||||
|
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||||
|
|
||||||
userRoute := apiRouter.Group("/user")
|
userRoute := apiRouter.Group("/user")
|
||||||
{
|
{
|
||||||
@@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
||||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||||
}
|
}
|
||||||
|
ratioSyncRoute := apiRouter.Group("/ratio_sync")
|
||||||
|
ratioSyncRoute.Use(middleware.RootAuth())
|
||||||
|
{
|
||||||
|
ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
|
||||||
|
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
|
||||||
|
}
|
||||||
channelRoute := apiRouter.Group("/channel")
|
channelRoute := apiRouter.Group("/channel")
|
||||||
channelRoute.Use(middleware.AdminAuth())
|
channelRoute.Use(middleware.AdminAuth())
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
|||||||
SetApiRouter(router)
|
SetApiRouter(router)
|
||||||
SetDashboardRouter(router)
|
SetDashboardRouter(router)
|
||||||
SetRelayRouter(router)
|
SetRelayRouter(router)
|
||||||
|
SetVideoRouter(router)
|
||||||
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
||||||
if common.IsMasterNode && frontendBaseUrl != "" {
|
if common.IsMasterNode && frontendBaseUrl != "" {
|
||||||
frontendBaseUrl = ""
|
frontendBaseUrl = ""
|
||||||
|
|||||||
17
router/video-router.go
Normal file
17
router/video-router.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"one-api/controller"
|
||||||
|
"one-api/middleware"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetVideoRouter(router *gin.Engine) {
|
||||||
|
videoV1Router := router.Group("/v1")
|
||||||
|
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||||
|
{
|
||||||
|
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||||
|
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
|||||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
toolTokens := CountTokenInput(countStr, request.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
|||||||
|
|
||||||
// Count tokens in system message
|
// Count tokens in system message
|
||||||
if request.System != "" {
|
if request.System != "" {
|
||||||
systemTokens, err := CountTokenInput(request.System, model)
|
systemTokens := CountTokenInput(request.System, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
switch request.Type {
|
switch request.Type {
|
||||||
case dto.RealtimeEventTypeSessionUpdate:
|
case dto.RealtimeEventTypeSessionUpdate:
|
||||||
if request.Session != nil {
|
if request.Session != nil {
|
||||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
textToken += msgTokens
|
textToken += msgTokens
|
||||||
}
|
}
|
||||||
case dto.RealtimeEventResponseAudioDelta:
|
case dto.RealtimeEventResponseAudioDelta:
|
||||||
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
audioToken += atk
|
audioToken += atk
|
||||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||||
// count text token
|
// count text token
|
||||||
tkm, err := CountTextToken(request.Delta, model)
|
tkm := CountTextToken(request.Delta, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
|
||||||
}
|
|
||||||
textToken += tkm
|
textToken += tkm
|
||||||
case dto.RealtimeEventInputAudioBufferAppend:
|
case dto.RealtimeEventInputAudioBufferAppend:
|
||||||
// count audio token
|
// count audio token
|
||||||
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
case "message":
|
case "message":
|
||||||
for _, content := range request.Item.Content {
|
for _, content := range request.Item.Content {
|
||||||
if content.Type == "input_text" {
|
if content.Type == "input_text" {
|
||||||
tokens, err := CountTextToken(content.Text, model)
|
tokens := CountTextToken(content.Text, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
textToken += tokens
|
textToken += tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
if !info.IsFirstRequest {
|
if !info.IsFirstRequest {
|
||||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||||
for _, tool := range info.RealtimeTools {
|
for _, tool := range info.RealtimeTools {
|
||||||
toolTokens, err := CountTokenInput(tool, model)
|
toolTokens := CountTokenInput(tool, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
textToken += 8
|
textToken += 8
|
||||||
textToken += toolTokens
|
textToken += toolTokens
|
||||||
}
|
}
|
||||||
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
|||||||
return tokenNum, nil
|
return tokenNum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenInput(input any, model string) (int, error) {
|
func CountTokenInput(input any, model string) int {
|
||||||
switch v := input.(type) {
|
switch v := input.(type) {
|
||||||
case string:
|
case string:
|
||||||
return CountTextToken(v, model)
|
return CountTextToken(v, model)
|
||||||
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
|
|||||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||||
tokens := 0
|
tokens := 0
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
if message.Delta.ToolCalls != nil {
|
if message.Delta.ToolCalls != nil {
|
||||||
for _, tool := range message.Delta.ToolCalls {
|
for _, tool := range message.Delta.ToolCalls {
|
||||||
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
tkm := CountTokenInput(tool.Function.Name, model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
|||||||
return tokens
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTTSToken(text string, model string) (int, error) {
|
func CountTTSToken(text string, model string) int {
|
||||||
if strings.HasPrefix(model, "tts") {
|
if strings.HasPrefix(model, "tts") {
|
||||||
return utf8.RuneCountInString(text), nil
|
return utf8.RuneCountInString(text)
|
||||||
} else {
|
} else {
|
||||||
return CountTextToken(text, model)
|
return CountTextToken(text, model)
|
||||||
}
|
}
|
||||||
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
|||||||
//}
|
//}
|
||||||
|
|
||||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||||
func CountTextToken(text string, model string) (int, error) {
|
func CountTextToken(text string, model string) int {
|
||||||
var err error
|
if text == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
return getTokenNum(tokenEncoder, text), err
|
return getTokenNum(tokenEncoder, text)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ import (
|
|||||||
// return 0, errors.New("unknown relay mode")
|
// return 0, errors.New("unknown relay mode")
|
||||||
//}
|
//}
|
||||||
|
|
||||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = promptTokens
|
||||||
ctkm, err := CountTextToken(responseText, modeName)
|
ctkm := CountTextToken(responseText, modeName)
|
||||||
usage.CompletionTokens = ctkm
|
usage.CompletionTokens = ctkm
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
return usage, err
|
return usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidUsage(usage *dto.Usage) bool {
|
func ValidUsage(usage *dto.Usage) bool {
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ var PayMethods = []map[string]string{
|
|||||||
{
|
{
|
||||||
"name": "支付宝",
|
"name": "支付宝",
|
||||||
"color": "rgba(var(--semi-blue-5), 1)",
|
"color": "rgba(var(--semi-blue-5), 1)",
|
||||||
"type": "zfb",
|
"type": "alipay",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "微信",
|
"name": "微信",
|
||||||
"color": "rgba(var(--semi-green-5), 1)",
|
"color": "rgba(var(--semi-green-5), 1)",
|
||||||
"type": "wx",
|
"type": "wxpay",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
|
|||||||
cacheRatioMapMutex.Lock()
|
cacheRatioMapMutex.Lock()
|
||||||
defer cacheRatioMapMutex.Unlock()
|
defer cacheRatioMapMutex.Unlock()
|
||||||
cacheRatioMap = make(map[string]float64)
|
cacheRatioMap = make(map[string]float64)
|
||||||
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
||||||
|
if err == nil {
|
||||||
|
InvalidateExposedDataCache()
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCacheRatio returns the cache ratio for a model
|
// GetCacheRatio returns the cache ratio for a model
|
||||||
@@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) {
|
|||||||
}
|
}
|
||||||
return ratio, true
|
return ratio, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetCacheRatioCopy() map[string]float64 {
|
||||||
|
cacheRatioMapMutex.RLock()
|
||||||
|
defer cacheRatioMapMutex.RUnlock()
|
||||||
|
copyMap := make(map[string]float64, len(cacheRatioMap))
|
||||||
|
for k, v := range cacheRatioMap {
|
||||||
|
copyMap[k] = v
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
}
|
||||||
|
|||||||
17
setting/ratio_setting/expose_ratio.go
Normal file
17
setting/ratio_setting/expose_ratio.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package ratio_setting
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
var exposeRatioEnabled atomic.Bool
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
exposeRatioEnabled.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetExposeRatioEnabled(enabled bool) {
|
||||||
|
exposeRatioEnabled.Store(enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsExposeRatioEnabled() bool {
|
||||||
|
return exposeRatioEnabled.Load()
|
||||||
|
}
|
||||||
55
setting/ratio_setting/exposed_cache.go
Normal file
55
setting/ratio_setting/exposed_cache.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package ratio_setting
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const exposedDataTTL = 30 * time.Second
|
||||||
|
|
||||||
|
type exposedCache struct {
|
||||||
|
data gin.H
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
exposedData atomic.Value
|
||||||
|
rebuildMu sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
func InvalidateExposedDataCache() {
|
||||||
|
exposedData.Store((*exposedCache)(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneGinH(src gin.H) gin.H {
|
||||||
|
dst := make(gin.H, len(src))
|
||||||
|
for k, v := range src {
|
||||||
|
dst[k] = v
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetExposedData() gin.H {
|
||||||
|
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||||
|
return cloneGinH(c.data)
|
||||||
|
}
|
||||||
|
rebuildMu.Lock()
|
||||||
|
defer rebuildMu.Unlock()
|
||||||
|
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||||
|
return cloneGinH(c.data)
|
||||||
|
}
|
||||||
|
newData := gin.H{
|
||||||
|
"model_ratio": GetModelRatioCopy(),
|
||||||
|
"completion_ratio": GetCompletionRatioCopy(),
|
||||||
|
"cache_ratio": GetCacheRatioCopy(),
|
||||||
|
"model_price": GetModelPriceCopy(),
|
||||||
|
}
|
||||||
|
exposedData.Store(&exposedCache{
|
||||||
|
data: newData,
|
||||||
|
expiresAt: time.Now().Add(exposedDataTTL),
|
||||||
|
})
|
||||||
|
return cloneGinH(newData)
|
||||||
|
}
|
||||||
@@ -317,7 +317,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
|||||||
modelPriceMapMutex.Lock()
|
modelPriceMapMutex.Lock()
|
||||||
defer modelPriceMapMutex.Unlock()
|
defer modelPriceMapMutex.Unlock()
|
||||||
modelPriceMap = make(map[string]float64)
|
modelPriceMap = make(map[string]float64)
|
||||||
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
||||||
|
if err == nil {
|
||||||
|
InvalidateExposedDataCache()
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
||||||
@@ -345,7 +349,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
|||||||
modelRatioMapMutex.Lock()
|
modelRatioMapMutex.Lock()
|
||||||
defer modelRatioMapMutex.Unlock()
|
defer modelRatioMapMutex.Unlock()
|
||||||
modelRatioMap = make(map[string]float64)
|
modelRatioMap = make(map[string]float64)
|
||||||
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
err := json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
||||||
|
if err == nil {
|
||||||
|
InvalidateExposedDataCache()
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理带有思考预算的模型名称,方便统一定价
|
// 处理带有思考预算的模型名称,方便统一定价
|
||||||
@@ -405,7 +413,11 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
|||||||
CompletionRatioMutex.Lock()
|
CompletionRatioMutex.Lock()
|
||||||
defer CompletionRatioMutex.Unlock()
|
defer CompletionRatioMutex.Unlock()
|
||||||
CompletionRatio = make(map[string]float64)
|
CompletionRatio = make(map[string]float64)
|
||||||
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
err := json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
||||||
|
if err == nil {
|
||||||
|
InvalidateExposedDataCache()
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCompletionRatio(name string) float64 {
|
func GetCompletionRatio(name string) float64 {
|
||||||
@@ -609,3 +621,33 @@ func GetImageRatio(name string) (float64, bool) {
|
|||||||
}
|
}
|
||||||
return ratio, true
|
return ratio, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetModelRatioCopy() map[string]float64 {
|
||||||
|
modelRatioMapMutex.RLock()
|
||||||
|
defer modelRatioMapMutex.RUnlock()
|
||||||
|
copyMap := make(map[string]float64, len(modelRatioMap))
|
||||||
|
for k, v := range modelRatioMap {
|
||||||
|
copyMap[k] = v
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelPriceCopy() map[string]float64 {
|
||||||
|
modelPriceMapMutex.RLock()
|
||||||
|
defer modelPriceMapMutex.RUnlock()
|
||||||
|
copyMap := make(map[string]float64, len(modelPriceMap))
|
||||||
|
for k, v := range modelPriceMap {
|
||||||
|
copyMap[k] = v
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCompletionRatioCopy() map[string]float64 {
|
||||||
|
CompletionRatioMutex.RLock()
|
||||||
|
defer CompletionRatioMutex.RUnlock()
|
||||||
|
copyMap := make(map[string]float64, len(CompletionRatio))
|
||||||
|
for k, v := range CompletionRatio {
|
||||||
|
copyMap[k] = v
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
}
|
||||||
|
|||||||
143
web/src/components/settings/ChannelSelectorModal.js
Normal file
143
web/src/components/settings/ChannelSelectorModal.js
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import React, { useState } from 'react';
|
||||||
|
import {
|
||||||
|
Modal,
|
||||||
|
Transfer,
|
||||||
|
Input,
|
||||||
|
Space,
|
||||||
|
Checkbox,
|
||||||
|
Avatar,
|
||||||
|
Highlight,
|
||||||
|
} from '@douyinfe/semi-ui';
|
||||||
|
import { IconClose } from '@douyinfe/semi-icons';
|
||||||
|
|
||||||
|
const CHANNEL_STATUS_CONFIG = {
|
||||||
|
1: { color: 'green', text: '启用' },
|
||||||
|
2: { color: 'red', text: '禁用' },
|
||||||
|
3: { color: 'amber', text: '自禁' },
|
||||||
|
default: { color: 'grey', text: '未知' }
|
||||||
|
};
|
||||||
|
|
||||||
|
const getChannelStatusConfig = (status) => {
|
||||||
|
return CHANNEL_STATUS_CONFIG[status] || CHANNEL_STATUS_CONFIG.default;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default function ChannelSelectorModal({
|
||||||
|
t,
|
||||||
|
visible,
|
||||||
|
onCancel,
|
||||||
|
onOk,
|
||||||
|
allChannels = [],
|
||||||
|
selectedChannelIds = [],
|
||||||
|
setSelectedChannelIds,
|
||||||
|
channelEndpoints,
|
||||||
|
updateChannelEndpoint,
|
||||||
|
}) {
|
||||||
|
const [searchText, setSearchText] = useState('');
|
||||||
|
|
||||||
|
const ChannelInfo = ({ item, showEndpoint = false, isSelected = false }) => {
|
||||||
|
const channelId = item.key || item.value;
|
||||||
|
const currentEndpoint = channelEndpoints[channelId];
|
||||||
|
const baseUrl = item._originalData?.base_url || '';
|
||||||
|
const status = item._originalData?.status || 0;
|
||||||
|
const statusConfig = getChannelStatusConfig(status);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Avatar color={statusConfig.color} size="small">
|
||||||
|
{statusConfig.text}
|
||||||
|
</Avatar>
|
||||||
|
<div className="info">
|
||||||
|
<div className="name">
|
||||||
|
{isSelected ? (
|
||||||
|
item.label
|
||||||
|
) : (
|
||||||
|
<Highlight sourceString={item.label} searchWords={[searchText]} />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="email" style={showEndpoint ? { display: 'flex', alignItems: 'center', gap: '4px' } : {}}>
|
||||||
|
<span className="text-xs text-gray-500 truncate max-w-[200px]" title={baseUrl}>
|
||||||
|
{isSelected ? (
|
||||||
|
baseUrl
|
||||||
|
) : (
|
||||||
|
<Highlight sourceString={baseUrl} searchWords={[searchText]} />
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
|
{showEndpoint && (
|
||||||
|
<Input
|
||||||
|
size="small"
|
||||||
|
value={currentEndpoint}
|
||||||
|
onChange={(value) => updateChannelEndpoint(channelId, value)}
|
||||||
|
placeholder="/api/ratio_config"
|
||||||
|
className="flex-1 text-xs"
|
||||||
|
style={{ fontSize: '12px' }}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{isSelected && !showEndpoint && (
|
||||||
|
<span className="text-xs text-gray-700 font-mono bg-gray-100 px-2 py-1 rounded ml-2">
|
||||||
|
{currentEndpoint}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderSourceItem = (item) => {
|
||||||
|
return (
|
||||||
|
<div className="components-transfer-source-item" key={item.key}>
|
||||||
|
<Checkbox
|
||||||
|
onChange={item.onChange}
|
||||||
|
checked={item.checked}
|
||||||
|
style={{ height: 52, alignItems: 'center' }}
|
||||||
|
>
|
||||||
|
<ChannelInfo item={item} showEndpoint={true} />
|
||||||
|
</Checkbox>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderSelectedItem = (item) => {
|
||||||
|
return (
|
||||||
|
<div className="components-transfer-selected-item" key={item.key}>
|
||||||
|
<ChannelInfo item={item} isSelected={true} />
|
||||||
|
<IconClose style={{ cursor: 'pointer' }} onClick={item.onRemove} />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const channelFilter = (input, item) => {
|
||||||
|
const searchLower = input.toLowerCase();
|
||||||
|
return item.label.toLowerCase().includes(searchLower) ||
|
||||||
|
(item._originalData?.base_url || '').toLowerCase().includes(searchLower);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
visible={visible}
|
||||||
|
onCancel={onCancel}
|
||||||
|
onOk={onOk}
|
||||||
|
title={<span className="text-lg font-semibold">{t('选择同步渠道')}</span>}
|
||||||
|
width={1000}
|
||||||
|
>
|
||||||
|
<Space vertical style={{ width: '100%' }}>
|
||||||
|
<Transfer
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
dataSource={allChannels}
|
||||||
|
value={selectedChannelIds}
|
||||||
|
onChange={setSelectedChannelIds}
|
||||||
|
renderSourceItem={renderSourceItem}
|
||||||
|
renderSelectedItem={renderSelectedItem}
|
||||||
|
filter={channelFilter}
|
||||||
|
inputProps={{ placeholder: t('搜索渠道名称或地址') }}
|
||||||
|
onSearch={setSearchText}
|
||||||
|
emptyContent={{
|
||||||
|
left: t('暂无渠道'),
|
||||||
|
right: t('暂无选择'),
|
||||||
|
search: t('无搜索结果'),
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Space>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js'
|
|||||||
import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js';
|
import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js';
|
||||||
import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js';
|
import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js';
|
||||||
import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js';
|
import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js';
|
||||||
|
import UpstreamRatioSync from '../../pages/Setting/Ratio/UpstreamRatioSync.js';
|
||||||
|
|
||||||
import { API, showError } from '../../helpers';
|
import { API, showError } from '../../helpers';
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ const RatioSetting = () => {
|
|||||||
GroupGroupRatio: '',
|
GroupGroupRatio: '',
|
||||||
AutoGroups: '',
|
AutoGroups: '',
|
||||||
DefaultUseAutoGroup: false,
|
DefaultUseAutoGroup: false,
|
||||||
|
ExposeRatioEnabled: false,
|
||||||
UserUsableGroups: '',
|
UserUsableGroups: '',
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -48,7 +50,7 @@ const RatioSetting = () => {
|
|||||||
// 如果后端返回的不是合法 JSON,直接展示
|
// 如果后端返回的不是合法 JSON,直接展示
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (['DefaultUseAutoGroup'].includes(item.key)) {
|
if (['DefaultUseAutoGroup', 'ExposeRatioEnabled'].includes(item.key)) {
|
||||||
newInputs[item.key] = item.value === 'true' ? true : false;
|
newInputs[item.key] = item.value === 'true' ? true : false;
|
||||||
} else {
|
} else {
|
||||||
newInputs[item.key] = item.value;
|
newInputs[item.key] = item.value;
|
||||||
@@ -78,10 +80,6 @@ const RatioSetting = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Spin spinning={loading} size='large'>
|
<Spin spinning={loading} size='large'>
|
||||||
{/* 分组倍率设置 */}
|
|
||||||
<Card style={{ marginTop: '10px' }}>
|
|
||||||
<GroupRatioSettings options={inputs} refresh={onRefresh} />
|
|
||||||
</Card>
|
|
||||||
{/* 模型倍率设置以及可视化编辑器 */}
|
{/* 模型倍率设置以及可视化编辑器 */}
|
||||||
<Card style={{ marginTop: '10px' }}>
|
<Card style={{ marginTop: '10px' }}>
|
||||||
<Tabs type='line'>
|
<Tabs type='line'>
|
||||||
@@ -100,8 +98,18 @@ const RatioSetting = () => {
|
|||||||
refresh={onRefresh}
|
refresh={onRefresh}
|
||||||
/>
|
/>
|
||||||
</Tabs.TabPane>
|
</Tabs.TabPane>
|
||||||
|
<Tabs.TabPane tab={t('上游倍率同步')} itemKey='upstream_sync'>
|
||||||
|
<UpstreamRatioSync
|
||||||
|
options={inputs}
|
||||||
|
refresh={onRefresh}
|
||||||
|
/>
|
||||||
|
</Tabs.TabPane>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
</Card>
|
</Card>
|
||||||
|
{/* 分组倍率设置 */}
|
||||||
|
<Card style={{ marginTop: '10px' }}>
|
||||||
|
<GroupRatioSettings options={inputs} refresh={onRefresh} />
|
||||||
|
</Card>
|
||||||
</Spin>
|
</Spin>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ import {
|
|||||||
XCircle,
|
XCircle,
|
||||||
Loader,
|
Loader,
|
||||||
List,
|
List,
|
||||||
Hash
|
Hash,
|
||||||
|
Video,
|
||||||
|
Sparkles
|
||||||
} from 'lucide-react';
|
} from 'lucide-react';
|
||||||
import {
|
import {
|
||||||
API,
|
API,
|
||||||
@@ -80,6 +82,7 @@ const COLUMN_KEYS = {
|
|||||||
TASK_STATUS: 'task_status',
|
TASK_STATUS: 'task_status',
|
||||||
PROGRESS: 'progress',
|
PROGRESS: 'progress',
|
||||||
FAIL_REASON: 'fail_reason',
|
FAIL_REASON: 'fail_reason',
|
||||||
|
RESULT_URL: 'result_url',
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderTimestamp = (timestampInSeconds) => {
|
const renderTimestamp = (timestampInSeconds) => {
|
||||||
@@ -150,6 +153,7 @@ const LogsTable = () => {
|
|||||||
[COLUMN_KEYS.TASK_STATUS]: true,
|
[COLUMN_KEYS.TASK_STATUS]: true,
|
||||||
[COLUMN_KEYS.PROGRESS]: true,
|
[COLUMN_KEYS.PROGRESS]: true,
|
||||||
[COLUMN_KEYS.FAIL_REASON]: true,
|
[COLUMN_KEYS.FAIL_REASON]: true,
|
||||||
|
[COLUMN_KEYS.RESULT_URL]: true,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -203,6 +207,12 @@ const LogsTable = () => {
|
|||||||
{t('生成歌词')}
|
{t('生成歌词')}
|
||||||
</Tag>
|
</Tag>
|
||||||
);
|
);
|
||||||
|
case 'generate':
|
||||||
|
return (
|
||||||
|
<Tag color='blue' size='large' shape='circle' prefixIcon={<Sparkles size={14} />}>
|
||||||
|
{t('生成视频')}
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
return (
|
return (
|
||||||
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||||
@@ -220,6 +230,12 @@ const LogsTable = () => {
|
|||||||
Suno
|
Suno
|
||||||
</Tag>
|
</Tag>
|
||||||
);
|
);
|
||||||
|
case 'kling':
|
||||||
|
return (
|
||||||
|
<Tag color='blue' size='large' shape='circle' prefixIcon={<Video size={14} />}>
|
||||||
|
Kling
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
return (
|
return (
|
||||||
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
<Tag color='white' size='large' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||||
@@ -411,10 +427,21 @@ const LogsTable = () => {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
key: COLUMN_KEYS.FAIL_REASON,
|
key: COLUMN_KEYS.FAIL_REASON,
|
||||||
title: t('失败原因'),
|
title: t('详情'),
|
||||||
dataIndex: 'fail_reason',
|
dataIndex: 'fail_reason',
|
||||||
fixed: 'right',
|
fixed: 'right',
|
||||||
render: (text, record, index) => {
|
render: (text, record, index) => {
|
||||||
|
// 仅当为视频生成任务且成功,且 fail_reason 是 URL 时显示可点击链接
|
||||||
|
const isVideoTask = record.action === 'generate';
|
||||||
|
const isSuccess = record.status === 'SUCCESS';
|
||||||
|
const isUrl = typeof text === 'string' && /^https?:\/\//.test(text);
|
||||||
|
if (isSuccess && isVideoTask && isUrl) {
|
||||||
|
return (
|
||||||
|
<a href={text} target="_blank" rel="noopener noreferrer">
|
||||||
|
{t('点击预览视频')}
|
||||||
|
</a>
|
||||||
|
);
|
||||||
|
}
|
||||||
if (!text) {
|
if (!text) {
|
||||||
return t('无');
|
return t('无');
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,4 +125,9 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'blue',
|
color: 'blue',
|
||||||
label: 'Coze',
|
label: 'Coze',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
value: 50,
|
||||||
|
color: 'green',
|
||||||
|
label: '可灵',
|
||||||
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend!
|
export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend!
|
||||||
|
|
||||||
|
export const DEFAULT_ENDPOINT = '/api/ratio_config';
|
||||||
@@ -1665,5 +1665,28 @@
|
|||||||
"确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?",
|
"确定清除所有失效兑换码?": "Are you sure you want to clear all invalid redemption codes?",
|
||||||
"将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.",
|
"将删除已使用、已禁用及过期的兑换码,此操作不可撤销。": "This will delete all used, disabled, and expired redemption codes, this operation cannot be undone.",
|
||||||
"选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)",
|
"选择过期时间(可选,留空为永久)": "Select expiration time (optional, leave blank for permanent)",
|
||||||
"请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)"
|
"请输入备注(仅管理员可见)": "Please enter a remark (only visible to administrators)",
|
||||||
|
"上游倍率同步": "Upstream ratio synchronization",
|
||||||
|
"获取渠道失败:": "Failed to get channels: ",
|
||||||
|
"请至少选择一个渠道": "Please select at least one channel",
|
||||||
|
"获取倍率失败:": "Failed to get ratios: ",
|
||||||
|
"后端请求失败": "Backend request failed",
|
||||||
|
"部分渠道测试失败:": "Some channels failed to test: ",
|
||||||
|
"已与上游倍率完全一致,无需同步": "The upstream ratio is completely consistent, no synchronization is required",
|
||||||
|
"请求后端接口失败:": "Failed to request the backend interface: ",
|
||||||
|
"同步成功": "Synchronization successful",
|
||||||
|
"部分保存失败": "Some settings failed to save",
|
||||||
|
"保存失败": "Save failed",
|
||||||
|
"选择同步渠道": "Select synchronization channel",
|
||||||
|
"应用同步": "Apply synchronization",
|
||||||
|
"倍率类型": "Ratio type",
|
||||||
|
"当前值": "Current value",
|
||||||
|
"上游值": "Upstream value",
|
||||||
|
"差异": "Difference",
|
||||||
|
"搜索渠道名称或地址": "Search channel name or address",
|
||||||
|
"缓存倍率": "Cache ratio",
|
||||||
|
"暂无差异化倍率显示": "No differential ratio display",
|
||||||
|
"请先选择同步渠道": "Please select the synchronization channel first",
|
||||||
|
"与本地相同": "Same as local",
|
||||||
|
"未找到匹配的模型": "No matching model found"
|
||||||
}
|
}
|
||||||
@@ -432,4 +432,72 @@ code {
|
|||||||
.semi-table-tbody>.semi-table-row {
|
.semi-table-tbody>.semi-table-row {
|
||||||
border-bottom: 1px solid rgba(0, 0, 0, 0.1);
|
border-bottom: 1px solid rgba(0, 0, 0, 0.1);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ==================== 同步倍率 - 渠道选择器 ==================== */
|
||||||
|
|
||||||
|
.components-transfer-source-item,
|
||||||
|
.components-transfer-selected-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.semi-transfer-left-list,
|
||||||
|
.semi-transfer-right-list {
|
||||||
|
-ms-overflow-style: none;
|
||||||
|
scrollbar-width: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.semi-transfer-left-list::-webkit-scrollbar,
|
||||||
|
.semi-transfer-right-list::-webkit-scrollbar {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-source-item .semi-checkbox,
|
||||||
|
.components-transfer-selected-item .semi-checkbox {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-source-item .semi-avatar,
|
||||||
|
.components-transfer-selected-item .semi-avatar {
|
||||||
|
margin-right: 12px;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-source-item .info,
|
||||||
|
.components-transfer-selected-item .info {
|
||||||
|
flex: 1;
|
||||||
|
overflow: hidden;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-source-item .name,
|
||||||
|
.components-transfer-selected-item .name {
|
||||||
|
font-weight: 500;
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
text-overflow: ellipsis;
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-source-item .email,
|
||||||
|
.components-transfer-selected-item .email {
|
||||||
|
font-size: 12px;
|
||||||
|
color: var(--semi-color-text-2);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-selected-item .semi-icon-close {
|
||||||
|
margin-left: 8px;
|
||||||
|
cursor: pointer;
|
||||||
|
color: var(--semi-color-text-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.components-transfer-selected-item .semi-icon-close:hover {
|
||||||
|
color: var(--semi-color-text-0);
|
||||||
}
|
}
|
||||||
@@ -1112,7 +1112,6 @@ const Detail = (props) => {
|
|||||||
</div>
|
</div>
|
||||||
<Tabs
|
<Tabs
|
||||||
type="button"
|
type="button"
|
||||||
preventScroll={true}
|
|
||||||
activeKey={activeChartTab}
|
activeKey={activeChartTab}
|
||||||
onChange={setActiveChartTab}
|
onChange={setActiveChartTab}
|
||||||
>
|
>
|
||||||
@@ -1389,7 +1388,6 @@ const Detail = (props) => {
|
|||||||
) : (
|
) : (
|
||||||
<Tabs
|
<Tabs
|
||||||
type="card"
|
type="card"
|
||||||
preventScroll={true}
|
|
||||||
collapsible
|
collapsible
|
||||||
activeKey={activeUptimeTab}
|
activeKey={activeUptimeTab}
|
||||||
onChange={setActiveUptimeTab}
|
onChange={setActiveUptimeTab}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ export default function ModelRatioSettings(props) {
|
|||||||
ModelRatio: '',
|
ModelRatio: '',
|
||||||
CacheRatio: '',
|
CacheRatio: '',
|
||||||
CompletionRatio: '',
|
CompletionRatio: '',
|
||||||
|
ExposeRatioEnabled: false,
|
||||||
});
|
});
|
||||||
const refForm = useRef();
|
const refForm = useRef();
|
||||||
const [inputsRow, setInputsRow] = useState(inputs);
|
const [inputsRow, setInputsRow] = useState(inputs);
|
||||||
@@ -206,6 +207,17 @@ export default function ModelRatioSettings(props) {
|
|||||||
/>
|
/>
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
<Row gutter={16}>
|
||||||
|
<Col span={16}>
|
||||||
|
<Form.Switch
|
||||||
|
label={t('暴露倍率接口')}
|
||||||
|
field={'ExposeRatioEnabled'}
|
||||||
|
onChange={(value) =>
|
||||||
|
setInputs({ ...inputs, ExposeRatioEnabled: value })
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
</Form.Section>
|
</Form.Section>
|
||||||
</Form>
|
</Form>
|
||||||
<Space>
|
<Space>
|
||||||
|
|||||||
503
web/src/pages/Setting/Ratio/UpstreamRatioSync.js
Normal file
503
web/src/pages/Setting/Ratio/UpstreamRatioSync.js
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
import React, { useState, useCallback, useMemo } from 'react';
|
||||||
|
import {
|
||||||
|
Button,
|
||||||
|
Table,
|
||||||
|
Tag,
|
||||||
|
Empty,
|
||||||
|
Checkbox,
|
||||||
|
Form,
|
||||||
|
Input,
|
||||||
|
} from '@douyinfe/semi-ui';
|
||||||
|
import { IconSearch } from '@douyinfe/semi-icons';
|
||||||
|
import {
|
||||||
|
RefreshCcw,
|
||||||
|
CheckSquare,
|
||||||
|
} from 'lucide-react';
|
||||||
|
import { API, showError, showSuccess, showWarning, stringToColor } from '../../../helpers';
|
||||||
|
import { DEFAULT_ENDPOINT } from '../../../constants';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import {
|
||||||
|
IllustrationNoResult,
|
||||||
|
IllustrationNoResultDark
|
||||||
|
} from '@douyinfe/semi-illustrations';
|
||||||
|
import ChannelSelectorModal from '../../../components/settings/ChannelSelectorModal';
|
||||||
|
|
||||||
|
export default function UpstreamRatioSync(props) {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const [modalVisible, setModalVisible] = useState(false);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [syncLoading, setSyncLoading] = useState(false);
|
||||||
|
|
||||||
|
// 渠道选择相关
|
||||||
|
const [allChannels, setAllChannels] = useState([]);
|
||||||
|
const [selectedChannelIds, setSelectedChannelIds] = useState([]);
|
||||||
|
|
||||||
|
// 渠道端点配置
|
||||||
|
const [channelEndpoints, setChannelEndpoints] = useState({}); // { channelId: endpoint }
|
||||||
|
|
||||||
|
// 差异数据和测试结果
|
||||||
|
const [differences, setDifferences] = useState({});
|
||||||
|
const [resolutions, setResolutions] = useState({});
|
||||||
|
|
||||||
|
// 是否已经执行过同步
|
||||||
|
const [hasSynced, setHasSynced] = useState(false);
|
||||||
|
|
||||||
|
// 分页相关状态
|
||||||
|
const [currentPage, setCurrentPage] = useState(1);
|
||||||
|
const [pageSize, setPageSize] = useState(10);
|
||||||
|
|
||||||
|
// 搜索相关状态
|
||||||
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
|
|
||||||
|
const fetchAllChannels = async () => {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const res = await API.get('/api/ratio_sync/channels');
|
||||||
|
|
||||||
|
if (res.data.success) {
|
||||||
|
const channels = res.data.data || [];
|
||||||
|
|
||||||
|
const transferData = channels.map(channel => ({
|
||||||
|
key: channel.id,
|
||||||
|
label: channel.name,
|
||||||
|
value: channel.id,
|
||||||
|
disabled: false,
|
||||||
|
_originalData: channel,
|
||||||
|
}));
|
||||||
|
|
||||||
|
setAllChannels(transferData);
|
||||||
|
|
||||||
|
const initialEndpoints = {};
|
||||||
|
transferData.forEach(channel => {
|
||||||
|
initialEndpoints[channel.key] = DEFAULT_ENDPOINT;
|
||||||
|
});
|
||||||
|
setChannelEndpoints(initialEndpoints);
|
||||||
|
} else {
|
||||||
|
showError(res.data.message);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showError(t('获取渠道失败:') + error.message);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const confirmChannelSelection = () => {
|
||||||
|
const selected = allChannels
|
||||||
|
.filter(ch => selectedChannelIds.includes(ch.value))
|
||||||
|
.map(ch => ch._originalData);
|
||||||
|
|
||||||
|
if (selected.length === 0) {
|
||||||
|
showWarning(t('请至少选择一个渠道'));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setModalVisible(false);
|
||||||
|
fetchRatiosFromChannels(selected);
|
||||||
|
};
|
||||||
|
|
||||||
|
const fetchRatiosFromChannels = async (channelList) => {
|
||||||
|
setSyncLoading(true);
|
||||||
|
|
||||||
|
const payload = {
|
||||||
|
channel_ids: channelList.map(ch => parseInt(ch.id)),
|
||||||
|
timeout: 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await API.post('/api/ratio_sync/fetch', payload);
|
||||||
|
|
||||||
|
if (!res.data.success) {
|
||||||
|
showError(res.data.message || t('后端请求失败'));
|
||||||
|
setSyncLoading(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { differences = {}, test_results = [] } = res.data.data;
|
||||||
|
|
||||||
|
const errorResults = test_results.filter(r => r.status === 'error');
|
||||||
|
if (errorResults.length > 0) {
|
||||||
|
showWarning(t('部分渠道测试失败:') + errorResults.map(r => `${r.name}: ${r.error}`).join(', '));
|
||||||
|
}
|
||||||
|
|
||||||
|
setDifferences(differences);
|
||||||
|
setResolutions({});
|
||||||
|
setHasSynced(true);
|
||||||
|
|
||||||
|
if (Object.keys(differences).length === 0) {
|
||||||
|
showSuccess(t('已与上游倍率完全一致,无需同步'));
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
showError(t('请求后端接口失败:') + e.message);
|
||||||
|
} finally {
|
||||||
|
setSyncLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const selectValue = (model, ratioType, value) => {
|
||||||
|
setResolutions(prev => ({
|
||||||
|
...prev,
|
||||||
|
[model]: {
|
||||||
|
...prev[model],
|
||||||
|
[ratioType]: value,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
const applySync = async () => {
|
||||||
|
const currentRatios = {
|
||||||
|
ModelRatio: JSON.parse(props.options.ModelRatio || '{}'),
|
||||||
|
CompletionRatio: JSON.parse(props.options.CompletionRatio || '{}'),
|
||||||
|
CacheRatio: JSON.parse(props.options.CacheRatio || '{}'),
|
||||||
|
ModelPrice: JSON.parse(props.options.ModelPrice || '{}'),
|
||||||
|
};
|
||||||
|
|
||||||
|
Object.entries(resolutions).forEach(([model, ratios]) => {
|
||||||
|
Object.entries(ratios).forEach(([ratioType, value]) => {
|
||||||
|
const optionKey = ratioType
|
||||||
|
.split('_')
|
||||||
|
.map(word => word.charAt(0).toUpperCase() + word.slice(1))
|
||||||
|
.join('');
|
||||||
|
currentRatios[optionKey][model] = parseFloat(value);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const updates = Object.entries(currentRatios).map(([key, value]) =>
|
||||||
|
API.put('/api/option/', {
|
||||||
|
key,
|
||||||
|
value: JSON.stringify(value, null, 2),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const results = await Promise.all(updates);
|
||||||
|
|
||||||
|
if (results.every(res => res.data.success)) {
|
||||||
|
showSuccess(t('同步成功'));
|
||||||
|
props.refresh();
|
||||||
|
|
||||||
|
setDifferences(prevDifferences => {
|
||||||
|
const newDifferences = { ...prevDifferences };
|
||||||
|
|
||||||
|
Object.entries(resolutions).forEach(([model, ratios]) => {
|
||||||
|
Object.keys(ratios).forEach(ratioType => {
|
||||||
|
if (newDifferences[model] && newDifferences[model][ratioType]) {
|
||||||
|
delete newDifferences[model][ratioType];
|
||||||
|
|
||||||
|
if (Object.keys(newDifferences[model]).length === 0) {
|
||||||
|
delete newDifferences[model];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return newDifferences;
|
||||||
|
});
|
||||||
|
|
||||||
|
setResolutions({});
|
||||||
|
} else {
|
||||||
|
showError(t('部分保存失败'));
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showError(t('保存失败'));
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const getCurrentPageData = (dataSource) => {
|
||||||
|
const startIndex = (currentPage - 1) * pageSize;
|
||||||
|
const endIndex = startIndex + pageSize;
|
||||||
|
return dataSource.slice(startIndex, endIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderHeader = () => (
|
||||||
|
<div className="flex flex-col w-full">
|
||||||
|
<div className="flex flex-col md:flex-row justify-between items-center gap-4 w-full">
|
||||||
|
<div className="flex gap-2 w-full md:w-auto order-2 md:order-1">
|
||||||
|
<Button
|
||||||
|
icon={<RefreshCcw size={14} />}
|
||||||
|
className="!rounded-full w-full md:w-auto mt-2"
|
||||||
|
onClick={() => {
|
||||||
|
setModalVisible(true);
|
||||||
|
fetchAllChannels();
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{t('选择同步渠道')}
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
{(() => {
|
||||||
|
const hasSelections = Object.keys(resolutions).length > 0;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Button
|
||||||
|
icon={<CheckSquare size={14} />}
|
||||||
|
type='secondary'
|
||||||
|
onClick={applySync}
|
||||||
|
disabled={!hasSelections}
|
||||||
|
className="!rounded-full w-full md:w-auto mt-2"
|
||||||
|
>
|
||||||
|
{t('应用同步')}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
})()}
|
||||||
|
|
||||||
|
<Input
|
||||||
|
prefix={<IconSearch size={14} />}
|
||||||
|
placeholder={t('搜索模型名称')}
|
||||||
|
value={searchKeyword}
|
||||||
|
onChange={setSearchKeyword}
|
||||||
|
className="!rounded-full w-full md:w-64 mt-2"
|
||||||
|
showClear
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
|
||||||
|
const renderDifferenceTable = () => {
|
||||||
|
const dataSource = useMemo(() => {
|
||||||
|
const tmp = [];
|
||||||
|
|
||||||
|
Object.entries(differences).forEach(([model, ratioTypes]) => {
|
||||||
|
Object.entries(ratioTypes).forEach(([ratioType, diff]) => {
|
||||||
|
tmp.push({
|
||||||
|
key: `${model}_${ratioType}`,
|
||||||
|
model,
|
||||||
|
ratioType,
|
||||||
|
current: diff.current,
|
||||||
|
upstreams: diff.upstreams,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return tmp;
|
||||||
|
}, [differences]);
|
||||||
|
|
||||||
|
const filteredDataSource = useMemo(() => {
|
||||||
|
if (!searchKeyword.trim()) {
|
||||||
|
return dataSource;
|
||||||
|
}
|
||||||
|
|
||||||
|
const keyword = searchKeyword.toLowerCase().trim();
|
||||||
|
return dataSource.filter(item =>
|
||||||
|
item.model.toLowerCase().includes(keyword)
|
||||||
|
);
|
||||||
|
}, [dataSource, searchKeyword]);
|
||||||
|
|
||||||
|
const upstreamNames = useMemo(() => {
|
||||||
|
const set = new Set();
|
||||||
|
filteredDataSource.forEach((row) => {
|
||||||
|
Object.keys(row.upstreams || {}).forEach((name) => set.add(name));
|
||||||
|
});
|
||||||
|
return Array.from(set);
|
||||||
|
}, [filteredDataSource]);
|
||||||
|
|
||||||
|
if (filteredDataSource.length === 0) {
|
||||||
|
return (
|
||||||
|
<Empty
|
||||||
|
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
|
||||||
|
darkModeImage={<IllustrationNoResultDark style={{ width: 150, height: 150 }} />}
|
||||||
|
description={
|
||||||
|
searchKeyword.trim()
|
||||||
|
? t('未找到匹配的模型')
|
||||||
|
: (Object.keys(differences).length === 0 ?
|
||||||
|
(hasSynced ? t('暂无差异化倍率显示') : t('请先选择同步渠道'))
|
||||||
|
: t('请先选择同步渠道'))
|
||||||
|
}
|
||||||
|
style={{ padding: 30 }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const columns = [
|
||||||
|
{
|
||||||
|
title: t('模型'),
|
||||||
|
dataIndex: 'model',
|
||||||
|
fixed: 'left',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: t('倍率类型'),
|
||||||
|
dataIndex: 'ratioType',
|
||||||
|
render: (text) => {
|
||||||
|
const typeMap = {
|
||||||
|
model_ratio: t('模型倍率'),
|
||||||
|
completion_ratio: t('补全倍率'),
|
||||||
|
cache_ratio: t('缓存倍率'),
|
||||||
|
model_price: t('固定价格'),
|
||||||
|
};
|
||||||
|
return <Tag color={stringToColor(text)} shape="circle">{typeMap[text] || text}</Tag>;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
title: t('当前值'),
|
||||||
|
dataIndex: 'current',
|
||||||
|
render: (text) => (
|
||||||
|
<Tag color={text !== null && text !== undefined ? 'blue' : 'default'} shape="circle">
|
||||||
|
{text !== null && text !== undefined ? text : t('未设置')}
|
||||||
|
</Tag>
|
||||||
|
),
|
||||||
|
},
|
||||||
|
...upstreamNames.map((upName) => {
|
||||||
|
const channelStats = (() => {
|
||||||
|
let selectableCount = 0;
|
||||||
|
let selectedCount = 0;
|
||||||
|
|
||||||
|
filteredDataSource.forEach((row) => {
|
||||||
|
const upstreamVal = row.upstreams?.[upName];
|
||||||
|
if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') {
|
||||||
|
selectableCount++;
|
||||||
|
const isSelected = resolutions[row.model]?.[row.ratioType] === upstreamVal;
|
||||||
|
if (isSelected) {
|
||||||
|
selectedCount++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
selectableCount,
|
||||||
|
selectedCount,
|
||||||
|
allSelected: selectableCount > 0 && selectedCount === selectableCount,
|
||||||
|
partiallySelected: selectedCount > 0 && selectedCount < selectableCount,
|
||||||
|
hasSelectableItems: selectableCount > 0
|
||||||
|
};
|
||||||
|
})();
|
||||||
|
|
||||||
|
const handleBulkSelect = (checked) => {
|
||||||
|
setResolutions((prev) => {
|
||||||
|
const newRes = { ...prev };
|
||||||
|
|
||||||
|
filteredDataSource.forEach((row) => {
|
||||||
|
const upstreamVal = row.upstreams?.[upName];
|
||||||
|
if (upstreamVal !== null && upstreamVal !== undefined && upstreamVal !== 'same') {
|
||||||
|
if (checked) {
|
||||||
|
if (!newRes[row.model]) newRes[row.model] = {};
|
||||||
|
newRes[row.model][row.ratioType] = upstreamVal;
|
||||||
|
} else {
|
||||||
|
if (newRes[row.model]) {
|
||||||
|
delete newRes[row.model][row.ratioType];
|
||||||
|
if (Object.keys(newRes[row.model]).length === 0) {
|
||||||
|
delete newRes[row.model];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return newRes;
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return {
|
||||||
|
title: channelStats.hasSelectableItems ? (
|
||||||
|
<Checkbox
|
||||||
|
checked={channelStats.allSelected}
|
||||||
|
indeterminate={channelStats.partiallySelected}
|
||||||
|
onChange={(e) => handleBulkSelect(e.target.checked)}
|
||||||
|
>
|
||||||
|
{upName}
|
||||||
|
</Checkbox>
|
||||||
|
) : (
|
||||||
|
<span>{upName}</span>
|
||||||
|
),
|
||||||
|
dataIndex: upName,
|
||||||
|
render: (_, record) => {
|
||||||
|
const upstreamVal = record.upstreams?.[upName];
|
||||||
|
|
||||||
|
if (upstreamVal === null || upstreamVal === undefined) {
|
||||||
|
return <Tag color="default" shape="circle">{t('未设置')}</Tag>;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (upstreamVal === 'same') {
|
||||||
|
return <Tag color="blue" shape="circle">{t('与本地相同')}</Tag>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const isSelected = resolutions[record.model]?.[record.ratioType] === upstreamVal;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Checkbox
|
||||||
|
checked={isSelected}
|
||||||
|
onChange={(e) => {
|
||||||
|
const isChecked = e.target.checked;
|
||||||
|
if (isChecked) {
|
||||||
|
selectValue(record.model, record.ratioType, upstreamVal);
|
||||||
|
} else {
|
||||||
|
setResolutions((prev) => {
|
||||||
|
const newRes = { ...prev };
|
||||||
|
if (newRes[record.model]) {
|
||||||
|
delete newRes[record.model][record.ratioType];
|
||||||
|
if (Object.keys(newRes[record.model]).length === 0) {
|
||||||
|
delete newRes[record.model];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newRes;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{upstreamVal}
|
||||||
|
</Checkbox>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Table
|
||||||
|
columns={columns}
|
||||||
|
dataSource={getCurrentPageData(filteredDataSource)}
|
||||||
|
pagination={{
|
||||||
|
currentPage: currentPage,
|
||||||
|
pageSize: pageSize,
|
||||||
|
total: filteredDataSource.length,
|
||||||
|
showSizeChanger: true,
|
||||||
|
showQuickJumper: true,
|
||||||
|
formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
|
||||||
|
start: page.currentStart,
|
||||||
|
end: page.currentEnd,
|
||||||
|
total: filteredDataSource.length,
|
||||||
|
}),
|
||||||
|
pageSizeOptions: ['5', '10', '20', '50'],
|
||||||
|
onChange: (page, size) => {
|
||||||
|
setCurrentPage(page);
|
||||||
|
setPageSize(size);
|
||||||
|
},
|
||||||
|
onShowSizeChange: (current, size) => {
|
||||||
|
setCurrentPage(1);
|
||||||
|
setPageSize(size);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
scroll={{ x: 'max-content' }}
|
||||||
|
size='middle'
|
||||||
|
loading={loading || syncLoading}
|
||||||
|
className="rounded-xl overflow-hidden"
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const updateChannelEndpoint = useCallback((channelId, endpoint) => {
|
||||||
|
setChannelEndpoints(prev => ({ ...prev, [channelId]: endpoint }));
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Form.Section text={renderHeader()}>
|
||||||
|
{renderDifferenceTable()}
|
||||||
|
</Form.Section>
|
||||||
|
|
||||||
|
<ChannelSelectorModal
|
||||||
|
t={t}
|
||||||
|
visible={modalVisible}
|
||||||
|
onCancel={() => setModalVisible(false)}
|
||||||
|
onOk={confirmChannelSelection}
|
||||||
|
allChannels={allChannels}
|
||||||
|
selectedChannelIds={selectedChannelIds}
|
||||||
|
setSelectedChannelIds={setSelectedChannelIds}
|
||||||
|
channelEndpoints={channelEndpoints}
|
||||||
|
updateChannelEndpoint={updateChannelEndpoint}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user