diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 368f92dd..f749f384 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -1,41 +1,35 @@ package controller import ( + "context" "encoding/json" "net/http" - "one-api/model" - "one-api/setting/ratio_setting" - "one-api/dto" + "strings" "sync" "time" + "one-api/common" + "one-api/dto" + "one-api/model" + "one-api/setting/ratio_setting" + "github.com/gin-gonic/gin" ) +const ( + defaultTimeoutSeconds = 10 + defaultEndpoint = "/api/ratio_config" + maxConcurrentFetches = 8 +) + +var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} + type upstreamResult struct { Name string `json:"name"` Data map[string]any `json:"data,omitempty"` Err string `json:"err,omitempty"` } -type TestResult struct { - Name string `json:"name"` - Status string `json:"status"` - Error string `json:"error,omitempty"` -} - -type DifferenceItem struct { - Current interface{} `json:"current"` // 当前本地值,可能为null - Upstreams map[string]interface{} `json:"upstreams"` // 上游值:具体值/"same"/null -} - -type SyncableChannel struct { - ID int `json:"id"` - Name string `json:"name"` - BaseURL string `json:"base_url"` - Status int `json:"status"` -} - func FetchUpstreamRatios(c *gin.Context) { var req dto.UpstreamRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -44,45 +38,80 @@ func FetchUpstreamRatios(c *gin.Context) { } if req.Timeout <= 0 { - req.Timeout = 10 + req.Timeout = defaultTimeoutSeconds } var upstreams []dto.UpstreamDTO + if len(req.ChannelIDs) > 0 { intIds := make([]int, 0, len(req.ChannelIDs)) for _, id64 := range req.ChannelIDs { intIds = append(intIds, int(id64)) } - dbChannels, _ := model.GetChannelsByIds(intIds) - for _, ch := range dbChannels { - upstreams = append(upstreams, dto.UpstreamDTO{ - Name: ch.Name, - BaseURL: ch.GetBaseURL(), - Endpoint: "", - }) + dbChannels, err := model.GetChannelsByIds(intIds) + if err != nil { + common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) + return } + for _, ch := range dbChannels { + if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { + upstreams = append(upstreams, dto.UpstreamDTO{ + Name: ch.Name, + BaseURL: strings.TrimRight(base, "/"), + Endpoint: "", + }) + } + } + } + + if len(upstreams) == 0 { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) + return } var wg sync.WaitGroup ch := make(chan upstreamResult, len(upstreams)) + sem := make(chan struct{}, maxConcurrentFetches) + + client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + for _, chn := range upstreams { wg.Add(1) go func(chItem dto.UpstreamDTO) { defer wg.Done() + + sem <- struct{}{} + defer func() { <-sem }() + endpoint := chItem.Endpoint if endpoint == "" { - endpoint = "/api/ratio_config" + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint } - url := chItem.BaseURL + endpoint - client := http.Client{Timeout: time.Duration(req.Timeout) * time.Second} - resp, err := client.Get(url) + fullURL := chItem.BaseURL + endpoint + + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) + defer cancel() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) if err != nil { + common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) + ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} + return + } + + resp, err := client.Do(httpReq) + if err != nil { + common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} return } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) ch <- upstreamResult{Name: chItem.Name, Err: resp.Status} return } @@ -92,6 +121,7 @@ func FetchUpstreamRatios(c *gin.Context) { Message string `json:"message"` } if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) ch <- upstreamResult{Name: chItem.Name, Err: err.Error()} return } @@ -149,7 +179,6 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { data map[string]any }) map[string]map[string]dto.DifferenceItem { differences := make(map[string]map[string]dto.DifferenceItem) - ratioTypes := []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} allModels := make(map[string]struct{}) diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go index 4f2fe06d..55a89025 100644 --- a/dto/ratio_sync.go +++ b/dto/ratio_sync.go @@ -19,9 +19,8 @@ type UpstreamDTO struct { } type UpstreamRequest struct { - ChannelIDs []int64 `json:"channel_ids"` - CustomChannels []UpstreamDTO `json:"custom_channels"` - Timeout int `json:"timeout"` + ChannelIDs []int64 `json:"channel_ids"` + Timeout int `json:"timeout"` } // TestResult 上游测试连通性结果