Files
new-api-hunter/controller/ratio_sync.go
2026-01-15 14:43:53 +08:00

545 lines
15 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/gin-gonic/gin"
)
// Tunables for fetching upstream ratio configurations and comparing them
// with the local settings.
const (
	// defaultTimeoutSeconds is the per-upstream request timeout applied when
	// the caller does not supply a positive one.
	defaultTimeoutSeconds = 10

	// defaultEndpoint is appended to an upstream base URL when no endpoint
	// is configured.
	defaultEndpoint = "/api/ratio_config"

	// maxConcurrentFetches bounds how many upstreams are queried at once.
	maxConcurrentFetches = 8

	// maxRatioConfigBytes caps how much of an upstream response body is
	// read (10 MiB).
	maxRatioConfigBytes = 10 * 1024 * 1024

	// floatEpsilon is the absolute tolerance used when comparing ratios.
	floatEpsilon = 1e-9
)
// nearlyEqual reports whether a and b differ by strictly less than
// floatEpsilon. Any comparison whose difference is NaN (a NaN operand, or
// two infinities of the same sign) reports false.
func nearlyEqual(a, b float64) bool {
	diff := a - b
	if diff < 0 {
		diff = -diff
	}
	return diff < floatEpsilon
}
// valuesEqual compares two loosely-typed ratio values. When both hold a
// float64 they are compared with nearlyEqual's tolerance; otherwise the
// interface values are compared with ==. Note that == panics if both values
// hold the same non-comparable dynamic type (e.g. two maps).
func valuesEqual(a, b interface{}) bool {
	if left, isFloat := a.(float64); isFloat {
		if right, alsoFloat := b.(float64); alsoFloat {
			return nearlyEqual(left, right)
		}
	}
	return a == b
}
// ratioTypes lists the pricing fields compared between the local settings and
// each upstream, in the order they are processed.
var ratioTypes = []string{
	"model_ratio",
	"completion_ratio",
	"cache_ratio",
	"model_price",
}
// upstreamResult is the outcome of fetching a single upstream; each fetch
// goroutine in FetchUpstreamRatios sends exactly one of these on its results
// channel.
type upstreamResult struct {
	// Name identifies the upstream: its name, or "name(id)" when it has a
	// non-zero ID.
	Name string `json:"name"`
	// Data holds the ratio configuration on success, keyed by ratio type
	// (e.g. "model_ratio") with per-model values beneath.
	Data map[string]any `json:"data,omitempty"`
	// Err holds the failure reason. The collector treats an empty Err as a
	// success.
	Err string `json:"err,omitempty"`
}
// FetchUpstreamRatios compares the local ratio settings with the ratio
// configuration published by a set of upstreams and reports the differences.
//
// The request body (dto.UpstreamRequest) either lists upstreams explicitly
// (BaseURL plus optional Endpoint; entries whose BaseURL does not start with
// "http" are skipped) or, when none are given, names channel IDs whose base
// URLs are used instead. Upstreams are fetched concurrently, at most
// maxConcurrentFetches at a time. Each fetch is bounded by req.Timeout seconds
// (defaultTimeoutSeconds when not positive). Transport errors are retried up to
// maxFetchAttempts times with exponential backoff.
//
// Two upstream payload shapes are accepted:
//   - type1 (/api/ratio_config): "data" is an object containing at least one
//     of ratioTypes and is used as-is;
//   - type2 (/api/pricing): "data" is a list of pricing items, converted into
//     the type1 shape (quota_type 1 -> model_price, otherwise model_ratio and
//     completion_ratio).
//
// The JSON response carries "differences" (see buildDifferences) and
// per-upstream "test_results".
func FetchUpstreamRatios(c *gin.Context) {
	// maxFetchAttempts is the number of tries per upstream on transport errors.
	const maxFetchAttempts = 3

	var req dto.UpstreamRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
		return
	}
	if req.Timeout <= 0 {
		req.Timeout = defaultTimeoutSeconds
	}

	// Resolve the upstream list: explicit upstreams win over channel IDs.
	var upstreams []dto.UpstreamDTO
	if len(req.Upstreams) > 0 {
		for _, u := range req.Upstreams {
			if strings.HasPrefix(u.BaseURL, "http") {
				if u.Endpoint == "" {
					u.Endpoint = defaultEndpoint
				}
				u.BaseURL = strings.TrimRight(u.BaseURL, "/")
				upstreams = append(upstreams, u)
			}
		}
	} else if len(req.ChannelIDs) > 0 {
		intIds := make([]int, 0, len(req.ChannelIDs))
		for _, id64 := range req.ChannelIDs {
			intIds = append(intIds, int(id64))
		}
		dbChannels, err := model.GetChannelsByIds(intIds)
		if err != nil {
			logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
			c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
			return
		}
		for _, ch := range dbChannels {
			if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
				upstreams = append(upstreams, dto.UpstreamDTO{
					ID:       ch.Id,
					Name:     ch.Name,
					BaseURL:  strings.TrimRight(base, "/"),
					Endpoint: "",
				})
			}
		}
	}
	if len(upstreams) == 0 {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
		return
	}

	var wg sync.WaitGroup
	ch := make(chan upstreamResult, len(upstreams))
	sem := make(chan struct{}, maxConcurrentFetches)
	dialer := &net.Dialer{Timeout: 10 * time.Second}
	transport := &http.Transport{
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		ResponseHeaderTimeout: 10 * time.Second,
	}
	// This transport is private to the current request. Release its pooled
	// keep-alive connections once every fetch has finished. Otherwise each
	// call would leave sockets (and their goroutines) open until
	// IdleConnTimeout.
	defer transport.CloseIdleConnections()
	if common.TLSInsecureSkipVerify {
		transport.TLSClientConfig = common.InsecureTLSConfig
	}
	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			host = addr
		}
		// For github.io hosts, try IPv4 first and fall back to IPv6 on failure.
		if strings.HasSuffix(host, "github.io") {
			if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
				return conn, nil
			}
			return dialer.DialContext(ctx, "tcp6", addr)
		}
		return dialer.DialContext(ctx, network, addr)
	}
	client := &http.Client{Transport: transport}

	for _, chn := range upstreams {
		wg.Add(1)
		go func(chItem dto.UpstreamDTO) {
			defer wg.Done()
			sem <- struct{}{}
			defer func() { <-sem }()

			// An absolute endpoint URL is used verbatim; otherwise it is joined
			// onto the base URL.
			endpoint := chItem.Endpoint
			var fullURL string
			if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
				fullURL = endpoint
			} else {
				if endpoint == "" {
					endpoint = defaultEndpoint
				} else if !strings.HasPrefix(endpoint, "/") {
					endpoint = "/" + endpoint
				}
				fullURL = chItem.BaseURL + endpoint
			}
			uniqueName := chItem.Name
			if chItem.ID != 0 {
				uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
			}

			ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
			defer cancel()
			httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
			if err != nil {
				logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
				return
			}

			// Simple retry with exponential backoff (200ms, 400ms, ...). There is
			// no sleep after the final attempt. The backoff is abandoned as soon
			// as the request context is done, since further tries would fail anyway.
			var resp *http.Response
			var lastErr error
		retry:
			for attempt := 0; attempt < maxFetchAttempts; attempt++ {
				resp, lastErr = client.Do(httpReq)
				if lastErr == nil || attempt == maxFetchAttempts-1 {
					break
				}
				select {
				case <-ctx.Done():
					break retry
				case <-time.After(time.Duration(200*(1<<attempt)) * time.Millisecond):
				}
			}
			if lastErr != nil {
				logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
				ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
				return
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
				ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
				return
			}

			// Content-Type and body-size checks. An unexpected content type is
			// only logged; the body is read through a size limit.
			if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
				logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
			}
			limited := io.LimitReader(resp.Body, maxRatioConfigBytes)

			// Two upstream payload formats are supported:
			//   type1: /api/ratio_config -> data is map[string]any containing
			//          model_ratio/completion_ratio/cache_ratio/model_price
			//   type2: /api/pricing -> data is a []Pricing list that must be
			//          converted into the same map shape as type1
			var body struct {
				Success bool            `json:"success"`
				Data    json.RawMessage `json:"data"`
				Message string          `json:"message"`
			}
			if err := json.NewDecoder(limited).Decode(&body); err != nil {
				logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
				return
			}
			if !body.Success {
				// The collector treats an empty Err as success, so an upstream
				// failure without a message must still produce a non-empty Err.
				msg := body.Message
				if msg == "" {
					msg = "上游返回失败"
				}
				ch <- upstreamResult{Name: uniqueName, Err: msg}
				return
			}

			// Try type1 first. This also applies when Data is empty, which is
			// compatible with most static ratio_config files.
			var type1Data map[string]any
			if err := json.Unmarshal(body.Data, &type1Data); err == nil {
				// Treat it as type1 when at least one ratioTypes key is present.
				for _, rt := range ratioTypes {
					if _, ok := type1Data[rt]; ok {
						ch <- upstreamResult{Name: uniqueName, Data: type1Data}
						return
					}
				}
			}

			// Not type1: try type2 (/api/pricing).
			var pricingItems []struct {
				ModelName       string  `json:"model_name"`
				QuotaType       int     `json:"quota_type"`
				ModelRatio      float64 `json:"model_ratio"`
				ModelPrice      float64 `json:"model_price"`
				CompletionRatio float64 `json:"completion_ratio"`
			}
			if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
				logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
				ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
				return
			}
			modelRatios := make(map[string]any)
			completionRatios := make(map[string]any)
			modelPrices := make(map[string]any)
			for _, item := range pricingItems {
				if item.QuotaType == 1 {
					modelPrices[item.ModelName] = item.ModelPrice
				} else {
					modelRatios[item.ModelName] = item.ModelRatio
					// completion_ratio may be 0; copy it anyway to stay
					// consistent with the upstream.
					completionRatios[item.ModelName] = item.CompletionRatio
				}
			}
			converted := make(map[string]any)
			if len(modelRatios) > 0 {
				converted["model_ratio"] = modelRatios
			}
			if len(completionRatios) > 0 {
				converted["completion_ratio"] = completionRatios
			}
			if len(modelPrices) > 0 {
				converted["model_price"] = modelPrices
			}
			ch <- upstreamResult{Name: uniqueName, Data: converted}
		}(chn)
	}
	wg.Wait()
	close(ch)

	localData := ratio_setting.GetExposedData()
	testResults := make([]dto.TestResult, 0, len(upstreams))
	var successfulChannels []struct {
		name string
		data map[string]any
	}
	for r := range ch {
		if r.Err != "" {
			testResults = append(testResults, dto.TestResult{
				Name:   r.Name,
				Status: "error",
				Error:  r.Err,
			})
			continue
		}
		testResults = append(testResults, dto.TestResult{
			Name:   r.Name,
			Status: "success",
		})
		successfulChannels = append(successfulChannels, struct {
			name string
			data map[string]any
		}{name: r.Name, data: r.Data})
	}
	differences := buildDifferences(localData, successfulChannels)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"differences":  differences,
			"test_results": testResults,
		},
	})
}
// buildDifferences compares the local ratio settings with the ratios fetched
// from each successful upstream. It returns, per model name and per ratio
// type, only the entries where at least one upstream disagrees with the local
// value or supplies a value that is missing locally.
//
// localData is read as ratio type -> map[string]float64 (model -> value);
// values of any other shape are silently ignored. NOTE(review): this relies on
// ratio_setting.GetExposedData returning map[string]float64 per ratio type —
// confirm. Each channel's data is read as ratio type -> map[string]any.
//
// In each returned DifferenceItem:
//   - Upstreams[channel] is the upstream value. It is the string "same" when
//     the value matches the local one (or is absent on both sides), and nil
//     when the upstream lacks a value that exists locally.
//   - Confidence[channel] is false only for channels that supplied both
//     model_ratio and completion_ratio, where the model has model_ratio == 37.5
//     and completion_ratio == 1. NOTE(review): 37.5/1 looks like a placeholder
//     pricing pattern; confirm the rationale.
//
// Channels that contribute no real difference anywhere are removed from every
// item. Items (and then models) whose remaining upstream values are all
// "same" are removed.
func buildDifferences(localData map[string]any, successfulChannels []struct {
	name string
	data map[string]any
}) map[string]map[string]dto.DifferenceItem {
	differences := make(map[string]map[string]dto.DifferenceItem)
	// Collect every model name that appears locally or in any upstream.
	allModels := make(map[string]struct{})
	for _, ratioType := range ratioTypes {
		if localRatioAny, ok := localData[ratioType]; ok {
			if localRatio, ok := localRatioAny.(map[string]float64); ok {
				for modelName := range localRatio {
					allModels[modelName] = struct{}{}
				}
			}
		}
	}
	for _, channel := range successfulChannels {
		for _, ratioType := range ratioTypes {
			if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
				for modelName := range upstreamRatio {
					allModels[modelName] = struct{}{}
				}
			}
		}
	}
	confidenceMap := make(map[string]map[string]bool)
	// Pre-processing: assess how trustworthy each channel's per-model data is.
	for _, channel := range successfulChannels {
		confidenceMap[channel.name] = make(map[string]bool)
		modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
		completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
		if hasModelRatio && hasCompletionRatio {
			// Walk every model and check whether it meets the "untrusted" condition.
			for modelName := range allModels {
				// Trusted by default.
				confidenceMap[channel.name][modelName] = true
				// Untrusted when model_ratio is 37.5 and completion_ratio is 1.
				if modelRatioVal, ok := modelRatios[modelName]; ok {
					if completionRatioVal, ok := completionRatios[modelName]; ok {
						// Compare as float64 (the fetch path stores numbers as float64).
						if modelRatioFloat, ok := modelRatioVal.(float64); ok {
							if completionRatioFloat, ok := completionRatioVal.(float64); ok {
								if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
									confidenceMap[channel.name][modelName] = false
								}
							}
						}
					}
				}
			}
		} else {
			// The channel lacks model_ratio or completion_ratio maps, so it is
			// not pricing-shaped data: mark every model as trusted.
			for modelName := range allModels {
				confidenceMap[channel.name][modelName] = true
			}
		}
	}
	for modelName := range allModels {
		for _, ratioType := range ratioTypes {
			// Local value for this model/ratio type, or nil when not configured.
			var localValue interface{} = nil
			if localRatioAny, ok := localData[ratioType]; ok {
				if localRatio, ok := localRatioAny.(map[string]float64); ok {
					if val, exists := localRatio[modelName]; exists {
						localValue = val
					}
				}
			}
			upstreamValues := make(map[string]interface{})
			confidenceValues := make(map[string]bool)
			hasUpstreamValue := false
			hasDifference := false
			for _, channel := range successfulChannels {
				var upstreamValue interface{} = nil
				if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
					if val, exists := upstreamRatio[modelName]; exists {
						upstreamValue = val
						hasUpstreamValue = true
						if localValue != nil && !valuesEqual(localValue, val) {
							hasDifference = true
						} else if valuesEqual(localValue, val) {
							// Values that match the local one are reported as the marker "same".
							upstreamValue = "same"
						}
					}
				}
				// Absent both locally and upstream: nothing to report for this channel.
				if upstreamValue == nil && localValue == nil {
					upstreamValue = "same"
				}
				// The upstream supplies a value that is missing locally.
				if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
					hasDifference = true
				}
				upstreamValues[channel.name] = upstreamValue
				confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
			}
			// Include the entry when a configured local value differs upstream,
			// or when it is missing locally but some upstream has the key.
			shouldInclude := false
			if localValue != nil {
				if hasDifference {
					shouldInclude = true
				}
			} else {
				if hasUpstreamValue {
					shouldInclude = true
				}
			}
			if shouldInclude {
				if differences[modelName] == nil {
					differences[modelName] = make(map[string]dto.DifferenceItem)
				}
				differences[modelName][ratioType] = dto.DifferenceItem{
					Current:    localValue,
					Upstreams:  upstreamValues,
					Confidence: confidenceValues,
				}
			}
		}
	}
	// Find channels that report at least one real (non-nil, non-"same") value
	// in any included item.
	channelHasDiff := make(map[string]bool)
	for _, ratioMap := range differences {
		for _, item := range ratioMap {
			for chName, val := range item.Upstreams {
				if val != nil && val != "same" {
					channelHasDiff[chName] = true
				}
			}
		}
	}
	// Drop channels without any difference from every item, then drop items
	// whose remaining values are all "same", and finally models left empty.
	// (Deleting from a map while ranging over it is allowed in Go.)
	for modelName, ratioMap := range differences {
		for ratioType, item := range ratioMap {
			for chName := range item.Upstreams {
				if !channelHasDiff[chName] {
					delete(item.Upstreams, chName)
					delete(item.Confidence, chName)
				}
			}
			allSame := true
			for _, v := range item.Upstreams {
				if v != "same" {
					allSame = false
					break
				}
			}
			if len(item.Upstreams) == 0 || allSame {
				delete(ratioMap, ratioType)
			} else {
				differences[modelName][ratioType] = item
			}
		}
		if len(ratioMap) == 0 {
			delete(differences, modelName)
		}
	}
	return differences
}
// GetSyncableChannels lists the sources that ratios can be synced from.
//
// It returns every channel that has a non-empty base URL, followed by one
// synthetic entry (ID -100) for the official ratio preset hosted at
// https://basellm.github.io. A database error is reported with HTTP 200 and
// "success": false, matching the other handlers in this file.
func GetSyncableChannels(c *gin.Context) {
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	// +1 leaves room for the official preset entry appended below.
	syncableChannels := make([]dto.SyncableChannel, 0, len(channels)+1)
	for _, channel := range channels {
		baseURL := channel.GetBaseURL()
		if baseURL == "" {
			continue
		}
		syncableChannels = append(syncableChannels, dto.SyncableChannel{
			ID:      channel.Id,
			Name:    channel.Name,
			BaseURL: baseURL,
			Status:  channel.Status,
		})
	}

	// Synthetic entry for the official preset. The name was previously stored
	// as double-encoded UTF-8 (mojibake) and is restored here.
	syncableChannels = append(syncableChannels, dto.SyncableChannel{
		ID:      -100,
		Name:    "官方倍率预设",
		BaseURL: "https://basellm.github.io",
		Status:  1,
	})
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    syncableChannels,
	})
}