diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 8791f516..d6d6e084 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -19,4 +19,12 @@ const ( type ChannelOtherSettings struct { AzureResponsesVersion string `json:"azure_responses_version,omitempty"` VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" + OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` +} + +func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool { + if s == nil || s.OpenRouterEnterprise == nil { + return false + } + return *s.OpenRouterEnterprise } diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index a065caff..79a0f706 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -265,6 +265,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http resp, err := client.Do(req) if err != nil { + logger.LogError(c, "do request failed: "+err.Error()) return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed")) } if resp == nil { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index d6b5b697..bafe73b9 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -10,6 +10,7 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/types" + "strings" "github.com/gin-gonic/gin" ) @@ -17,10 +18,7 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("not implemented") } func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { openaiAdaptor := openai.Adaptor{} @@ -31,32 +29,21 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{ IncludeUsage: true, } - return requestOpenAI2Ollama(c, openaiRequest.(*dto.GeneralOpenAIRequest)) + // map to ollama chat request (Claude -> OpenAI -> Ollama chat) + return openAIChatToOllamaChat(c, openaiRequest.(*dto.GeneralOpenAIRequest)) } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { return nil, errors.New("not implemented") } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == types.RelayFormatClaude { - return info.ChannelBaseUrl + "/v1/chat/completions", nil - } - switch info.RelayMode { - case relayconstant.RelayModeEmbeddings: - return info.ChannelBaseUrl + "/api/embed", nil - default: - return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil - } + if info.RelayMode == relayconstant.RelayModeEmbeddings { return info.ChannelBaseUrl + "/api/embed", nil } + if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { return info.ChannelBaseUrl + "/api/generate", nil } + return info.ChannelBaseUrl + "/api/chat", nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -66,10 +53,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel } func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") + if request == nil { return nil, errors.New("request is nil") } + // decide generate or chat + if strings.Contains(info.RequestURLPath, "/v1/completions") || info.RelayMode == relayconstant.RelayModeCompletions { + return openAIToGenerate(c, request) } - return requestOpenAI2Ollama(c, request) + return openAIChatToOllamaChat(c, request) } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { @@ -80,10 +69,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela return requestOpenAI2Embeddings(request), nil } -func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { - // TODO implement me - return nil, errors.New("not implemented") -} +func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) @@ -92,15 +78,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayMode { case relayconstant.RelayModeEmbeddings: - usage, err = ollamaEmbeddingHandler(c, info, resp) + return ollamaEmbeddingHandler(c, info, resp) default: if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) + return ollamaStreamHandler(c, info, resp) } + return ollamaChatHandler(c, info, resp) } - return } func (a *Adaptor) GetModelList() []string { diff --git a/relay/channel/ollama/dto.go b/relay/channel/ollama/dto.go index 317c2a4a..45e49ab4 100644 --- a/relay/channel/ollama/dto.go +++ b/relay/channel/ollama/dto.go @@ -2,48 +2,69 @@ package ollama import ( "encoding/json" - "one-api/dto" ) -type OllamaRequest struct { - Model string `json:"model,omitempty"` - Messages []dto.Message `json:"messages,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Seed float64 `json:"seed,omitempty"` - Topp float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Stop any `json:"stop,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - Tools []dto.ToolCallRequest `json:"tools,omitempty"` - ResponseFormat any `json:"response_format,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - Suffix any `json:"suffix,omitempty"` - StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"` - Prompt any `json:"prompt,omitempty"` - Think json.RawMessage `json:"think,omitempty"` +type OllamaChatMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` + ToolCalls []OllamaToolCall `json:"tool_calls,omitempty"` + ToolName string `json:"tool_name,omitempty"` + Thinking json.RawMessage `json:"thinking,omitempty"` } -type Options struct { - Seed int `json:"seed,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` +type OllamaToolFunction struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters interface{} `json:"parameters,omitempty"` +} + +type OllamaTool struct { + Type string `json:"type"` + Function OllamaToolFunction `json:"function"` +} + +type OllamaToolCall struct { + Function struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + } `json:"function"` +} + +type OllamaChatRequest struct { + Model string `json:"model"` + Messages []OllamaChatMessage `json:"messages"` + Tools interface{} `json:"tools,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` +} + +type OllamaGenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + Suffix string `json:"suffix,omitempty"` + Images []string `json:"images,omitempty"` + Format interface{} `json:"format,omitempty"` + Stream bool `json:"stream,omitempty"` + Options map[string]any `json:"options,omitempty"` + KeepAlive interface{} `json:"keep_alive,omitempty"` + Think json.RawMessage `json:"think,omitempty"` } type OllamaEmbeddingRequest struct { - Model string `json:"model,omitempty"` - Input []string `json:"input"` - Options *Options `json:"options,omitempty"` + Model string `json:"model"` + Input interface{} `json:"input"` + Options map[string]any `json:"options,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type OllamaEmbeddingResponse struct { - Error string `json:"error,omitempty"` - Model string `json:"model"` - Embedding [][]float64 `json:"embeddings,omitempty"` + Error string `json:"error,omitempty"` + Model string `json:"model"` + Embeddings [][]float64 `json:"embeddings"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` } + diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 27c67b4e..3b67f952 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -1,6 +1,7 @@ package ollama import ( + "encoding/json" "fmt" "io" "net/http" @@ -14,121 +15,176 @@ import ( "github.com/gin-gonic/gin" ) -func requestOpenAI2Ollama(c *gin.Context, request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) { - messages := make([]dto.Message, 0, len(request.Messages)) - for _, message := range request.Messages { - if !message.IsStringContent() { - mediaMessages := message.ParseContent() - for j, mediaMessage := range mediaMessages { - if mediaMessage.Type == dto.ContentTypeImageURL { - imageUrl := mediaMessage.GetImageMedia() - // check if not base64 - if strings.HasPrefix(imageUrl.Url, "http") { - fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Ollama") - if err != nil { - return nil, err +func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaChatRequest, error) { + chatReq := &OllamaChatRequest{ + Model: r.Model, + Stream: r.Stream, + Options: map[string]any{}, + Think: r.Think, + } + if r.ResponseFormat != nil { + if r.ResponseFormat.Type == "json" { + chatReq.Format = "json" + } else if r.ResponseFormat.Type == "json_schema" { + if len(r.ResponseFormat.JsonSchema) > 0 { + var schema any + _ = json.Unmarshal(r.ResponseFormat.JsonSchema, &schema) + chatReq.Format = schema + } + } + } + + // options mapping + if r.Temperature != nil { chatReq.Options["temperature"] = r.Temperature } + if r.TopP != 0 { chatReq.Options["top_p"] = r.TopP } + if r.TopK != 0 { chatReq.Options["top_k"] = r.TopK } + if r.FrequencyPenalty != 0 { chatReq.Options["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { chatReq.Options["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { chatReq.Options["seed"] = int(r.Seed) } + if mt := r.GetMaxTokens(); mt != 0 { chatReq.Options["num_predict"] = int(mt) } + + if r.Stop != nil { + switch v := r.Stop.(type) { + case string: + chatReq.Options["stop"] = []string{v} + case []string: + chatReq.Options["stop"] = v + case []any: + arr := make([]string,0,len(v)) + for _, i := range v { if s,ok:=i.(string); ok { arr = append(arr,s) } } + if len(arr)>0 { chatReq.Options["stop"] = arr } + } + } + + if len(r.Tools) > 0 { + tools := make([]OllamaTool,0,len(r.Tools)) + for _, t := range r.Tools { + tools = append(tools, OllamaTool{Type: "function", Function: OllamaToolFunction{Name: t.Function.Name, Description: t.Function.Description, Parameters: t.Function.Parameters}}) + } + chatReq.Tools = tools + } + + chatReq.Messages = make([]OllamaChatMessage,0,len(r.Messages)) + for _, m := range r.Messages { + var textBuilder strings.Builder + var images []string + if m.IsStringContent() { + textBuilder.WriteString(m.StringContent()) + } else { + parts := m.ParseContent() + for _, part := range parts { + if part.Type == dto.ContentTypeImageURL { + img := part.GetImageMedia() + if img != nil && img.Url != "" { + var base64Data string + if strings.HasPrefix(img.Url, "http") { + fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat") + if err != nil { return nil, err } + base64Data = fileData.Base64Data + } else if strings.HasPrefix(img.Url, "data:") { + if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) { base64Data = img.Url[idx+1:] } + } else { + base64Data = img.Url } - imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data) + if base64Data != "" { images = append(images, base64Data) } } - mediaMessage.ImageUrl = imageUrl - mediaMessages[j] = mediaMessage + } else if part.Type == dto.ContentTypeText { + textBuilder.WriteString(part.Text) } } - message.SetMediaContent(mediaMessages) } - messages = append(messages, dto.Message{ - Role: message.Role, - Content: message.Content, - ToolCalls: message.ToolCalls, - ToolCallId: message.ToolCallId, - }) + cm := OllamaChatMessage{Role: m.Role, Content: textBuilder.String()} + if len(images)>0 { cm.Images = images } + if m.Role == "tool" && m.Name != nil { cm.ToolName = *m.Name } + if m.ToolCalls != nil && len(m.ToolCalls) > 0 { + parsed := m.ParseToolCalls() + if len(parsed) > 0 { + calls := make([]OllamaToolCall,0,len(parsed)) + for _, tc := range parsed { + var args interface{} + if tc.Function.Arguments != "" { _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) } + if args==nil { args = map[string]any{} } + oc := OllamaToolCall{} + oc.Function.Name = tc.Function.Name + oc.Function.Arguments = args + calls = append(calls, oc) + } + cm.ToolCalls = calls + } + } + chatReq.Messages = append(chatReq.Messages, cm) } - str, ok := request.Stop.(string) - var Stop []string - if ok { - Stop = []string{str} - } else { - Stop, _ = request.Stop.([]string) - } - ollamaRequest := &OllamaRequest{ - Model: request.Model, - Messages: messages, - Stream: request.Stream, - Temperature: request.Temperature, - Seed: request.Seed, - Topp: request.TopP, - TopK: request.TopK, - Stop: Stop, - Tools: request.Tools, - MaxTokens: request.GetMaxTokens(), - ResponseFormat: request.ResponseFormat, - FrequencyPenalty: request.FrequencyPenalty, - PresencePenalty: request.PresencePenalty, - Prompt: request.Prompt, - StreamOptions: request.StreamOptions, - Suffix: request.Suffix, - } - ollamaRequest.Think = request.Think - return ollamaRequest, nil + return chatReq, nil } -func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest { - return &OllamaEmbeddingRequest{ - Model: request.Model, - Input: request.ParseInput(), - Options: &Options{ - Seed: int(request.Seed), - Temperature: request.Temperature, - TopP: request.TopP, - FrequencyPenalty: request.FrequencyPenalty, - PresencePenalty: request.PresencePenalty, - }, +// openAIToGenerate converts OpenAI completions request to Ollama generate +func openAIToGenerate(c *gin.Context, r *dto.GeneralOpenAIRequest) (*OllamaGenerateRequest, error) { + gen := &OllamaGenerateRequest{ + Model: r.Model, + Stream: r.Stream, + Options: map[string]any{}, + Think: r.Think, } + // Prompt may be in r.Prompt (string or []any) + if r.Prompt != nil { + switch v := r.Prompt.(type) { + case string: + gen.Prompt = v + case []any: + var sb strings.Builder + for _, it := range v { if s,ok:=it.(string); ok { sb.WriteString(s) } } + gen.Prompt = sb.String() + default: + gen.Prompt = fmt.Sprintf("%v", r.Prompt) + } + } + if r.Suffix != nil { if s,ok:=r.Suffix.(string); ok { gen.Suffix = s } } + if r.ResponseFormat != nil { + if r.ResponseFormat.Type == "json" { gen.Format = "json" } else if r.ResponseFormat.Type == "json_schema" { var schema any; _ = json.Unmarshal(r.ResponseFormat.JsonSchema,&schema); gen.Format=schema } + } + if r.Temperature != nil { gen.Options["temperature"] = r.Temperature } + if r.TopP != 0 { gen.Options["top_p"] = r.TopP } + if r.TopK != 0 { gen.Options["top_k"] = r.TopK } + if r.FrequencyPenalty != 0 { gen.Options["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { gen.Options["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { gen.Options["seed"] = int(r.Seed) } + if mt := r.GetMaxTokens(); mt != 0 { gen.Options["num_predict"] = int(mt) } + if r.Stop != nil { + switch v := r.Stop.(type) { + case string: gen.Options["stop"] = []string{v} + case []string: gen.Options["stop"] = v + case []any: arr:=make([]string,0,len(v)); for _,i:= range v { if s,ok:=i.(string); ok { arr=append(arr,s) } }; if len(arr)>0 { gen.Options["stop"]=arr } + } + } + return gen, nil +} + +func requestOpenAI2Embeddings(r dto.EmbeddingRequest) *OllamaEmbeddingRequest { + opts := map[string]any{} + if r.Temperature != nil { opts["temperature"] = r.Temperature } + if r.TopP != 0 { opts["top_p"] = r.TopP } + if r.FrequencyPenalty != 0 { opts["frequency_penalty"] = r.FrequencyPenalty } + if r.PresencePenalty != 0 { opts["presence_penalty"] = r.PresencePenalty } + if r.Seed != 0 { opts["seed"] = int(r.Seed) } + if r.Dimensions != 0 { opts["dimensions"] = r.Dimensions } + input := r.ParseInput() + if len(input)==1 { return &OllamaEmbeddingRequest{Model:r.Model, Input: input[0], Options: opts, Dimensions:r.Dimensions} } + return &OllamaEmbeddingRequest{Model:r.Model, Input: input, Options: opts, Dimensions:r.Dimensions} } func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { - var ollamaEmbeddingResponse OllamaEmbeddingResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } + var oResp OllamaEmbeddingResponse + body, err := io.ReadAll(resp.Body) + if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } service.CloseResponseBodyGracefully(resp) - err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - if ollamaEmbeddingResponse.Error != "" { - return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding) - data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1) - data = append(data, dto.OpenAIEmbeddingResponseItem{ - Embedding: flattenedEmbeddings, - Object: "embedding", - }) - usage := &dto.Usage{ - TotalTokens: info.PromptTokens, - CompletionTokens: 0, - PromptTokens: info.PromptTokens, - } - embeddingResponse := &dto.OpenAIEmbeddingResponse{ - Object: "list", - Data: data, - Model: info.UpstreamModelName, - Usage: *usage, - } - doResponseBody, err := common.Marshal(embeddingResponse) - if err != nil { - return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) - } - service.IOCopyBytesGracefully(c, resp, doResponseBody) + if err = common.Unmarshal(body, &oResp); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if oResp.Error != "" { return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", oResp.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + data := make([]dto.OpenAIEmbeddingResponseItem,0,len(oResp.Embeddings)) + for i, emb := range oResp.Embeddings { data = append(data, dto.OpenAIEmbeddingResponseItem{Index:i,Object:"embedding",Embedding:emb}) } + usage := &dto.Usage{PromptTokens: oResp.PromptEvalCount, CompletionTokens:0, TotalTokens: oResp.PromptEvalCount} + embResp := &dto.OpenAIEmbeddingResponse{Object:"list", Data:data, Model: info.UpstreamModelName, Usage:*usage} + out, _ := common.Marshal(embResp) + service.IOCopyBytesGracefully(c, resp, out) return usage, nil } -func flattenEmbeddings(embeddings [][]float64) []float64 { - flattened := []float64{} - for _, row := range embeddings { - flattened = append(flattened, row...) - } - return flattened -} diff --git a/relay/channel/ollama/stream.go b/relay/channel/ollama/stream.go new file mode 100644 index 00000000..964f11d9 --- /dev/null +++ b/relay/channel/ollama/stream.go @@ -0,0 +1,210 @@ +package ollama + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/logger" + relaycommon "one-api/relay/common" + "one-api/relay/helper" + "one-api/service" + "one-api/types" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +type ollamaChatStreamChunk struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + // chat + Message *struct { + Role string `json:"role"` + Content string `json:"content"` + Thinking json.RawMessage `json:"thinking"` + ToolCalls []struct { + Function struct { + Name string `json:"name"` + Arguments interface{} `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + // generate + Response string `json:"response"` + Done bool `json:"done"` + DoneReason string `json:"done_reason"` + TotalDuration int64 `json:"total_duration"` + LoadDuration int64 `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` + PromptEvalDuration int64 `json:"prompt_eval_duration"` + EvalDuration int64 `json:"eval_duration"` +} + +func toUnix(ts string) int64 { + if ts == "" { return time.Now().Unix() } + // try time.RFC3339 or with nanoseconds + t, err := time.Parse(time.RFC3339Nano, ts) + if err != nil { t2, err2 := time.Parse(time.RFC3339, ts); if err2==nil { return t2.Unix() }; return time.Now().Unix() } + return t.Unix() +} + +func ollamaStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + if resp == nil || resp.Body == nil { return nil, types.NewOpenAIError(fmt.Errorf("empty response"), types.ErrorCodeBadResponse, http.StatusBadRequest) } + defer service.CloseResponseBodyGracefully(resp) + + helper.SetEventStreamHeaders(c) + scanner := bufio.NewScanner(resp.Body) + usage := &dto.Usage{} + var model = info.UpstreamModelName + var responseId = common.GetUUID() + var created = time.Now().Unix() + var toolCallIndex int + start := helper.GenerateStartEmptyResponse(responseId, created, model, nil) + if data, err := common.Marshal(start); err == nil { _ = helper.StringData(c, string(data)) } + + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + if line == "" { continue } + var chunk ollamaChatStreamChunk + if err := json.Unmarshal([]byte(line), &chunk); err != nil { + logger.LogError(c, "ollama stream json decode error: "+err.Error()+" line="+line) + return usage, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + if chunk.Model != "" { model = chunk.Model } + created = toUnix(chunk.CreatedAt) + + if !chunk.Done { + // delta content + var content string + if chunk.Message != nil { content = chunk.Message.Content } else { content = chunk.Response } + delta := dto.ChatCompletionsStreamResponse{ + Id: responseId, + Object: "chat.completion.chunk", + Created: created, + Model: model, + Choices: []dto.ChatCompletionsStreamResponseChoice{ { + Index: 0, + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant" }, + } }, + } + if content != "" { delta.Choices[0].Delta.SetContentString(content) } + if chunk.Message != nil && len(chunk.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(chunk.Message.Thinking)) + if raw != "" && raw != "null" { delta.Choices[0].Delta.SetReasoningContent(raw) } + } + // tool calls + if chunk.Message != nil && len(chunk.Message.ToolCalls) > 0 { + delta.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse,0,len(chunk.Message.ToolCalls)) + for _, tc := range chunk.Message.ToolCalls { + // arguments -> string + argBytes, _ := json.Marshal(tc.Function.Arguments) + toolId := fmt.Sprintf("call_%d", toolCallIndex) + tr := dto.ToolCallResponse{ID:toolId, Type:"function", Function: dto.FunctionResponse{Name: tc.Function.Name, Arguments: string(argBytes)}} + tr.SetIndex(toolCallIndex) + toolCallIndex++ + delta.Choices[0].Delta.ToolCalls = append(delta.Choices[0].Delta.ToolCalls, tr) + } + } + if data, err := common.Marshal(delta); err == nil { _ = helper.StringData(c, string(data)) } + continue + } + // done frame + // finalize once and break loop + usage.PromptTokens = chunk.PromptEvalCount + usage.CompletionTokens = chunk.EvalCount + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + finishReason := chunk.DoneReason + if finishReason == "" { finishReason = "stop" } + // emit stop delta + if stop := helper.GenerateStopResponse(responseId, created, model, finishReason); stop != nil { + if data, err := common.Marshal(stop); err == nil { _ = helper.StringData(c, string(data)) } + } + // emit usage frame + if final := helper.GenerateFinalUsageResponse(responseId, created, model, *usage); final != nil { + if data, err := common.Marshal(final); err == nil { _ = helper.StringData(c, string(data)) } + } + // send [DONE] + helper.Done(c) + break + } + if err := scanner.Err(); err != nil && err != io.EOF { logger.LogError(c, "ollama stream scan error: "+err.Error()) } + return usage, nil +} + +// non-stream handler for chat/generate +func ollamaChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { + body, err := io.ReadAll(resp.Body) + if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) } + service.CloseResponseBodyGracefully(resp) + raw := string(body) + if common.DebugEnabled { println("ollama non-stream raw resp:", raw) } + + lines := strings.Split(raw, "\n") + var ( + aggContent strings.Builder + reasoningBuilder strings.Builder + lastChunk ollamaChatStreamChunk + parsedAny bool + ) + for _, ln := range lines { + ln = strings.TrimSpace(ln) + if ln == "" { continue } + var ck ollamaChatStreamChunk + if err := json.Unmarshal([]byte(ln), &ck); err != nil { + if len(lines) == 1 { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + continue + } + parsedAny = true + lastChunk = ck + if ck.Message != nil && len(ck.Message.Thinking) > 0 { + raw := strings.TrimSpace(string(ck.Message.Thinking)) + if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } + } + if ck.Message != nil && ck.Message.Content != "" { aggContent.WriteString(ck.Message.Content) } else if ck.Response != "" { aggContent.WriteString(ck.Response) } + } + + if !parsedAny { + var single ollamaChatStreamChunk + if err := json.Unmarshal(body, &single); err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + lastChunk = single + if single.Message != nil { + if len(single.Message.Thinking) > 0 { raw := strings.TrimSpace(string(single.Message.Thinking)); if raw != "" && raw != "null" { reasoningBuilder.WriteString(raw) } } + aggContent.WriteString(single.Message.Content) + } else { aggContent.WriteString(single.Response) } + } + + model := lastChunk.Model + if model == "" { model = info.UpstreamModelName } + created := toUnix(lastChunk.CreatedAt) + usage := &dto.Usage{PromptTokens: lastChunk.PromptEvalCount, CompletionTokens: lastChunk.EvalCount, TotalTokens: lastChunk.PromptEvalCount + lastChunk.EvalCount} + content := aggContent.String() + finishReason := lastChunk.DoneReason + if finishReason == "" { finishReason = "stop" } + + msg := dto.Message{Role: "assistant", Content: contentPtr(content)} + if rc := reasoningBuilder.String(); rc != "" { msg.ReasoningContent = rc } + full := dto.OpenAITextResponse{ + Id: common.GetUUID(), + Model: model, + Object: "chat.completion", + Created: created, + Choices: []dto.OpenAITextResponseChoice{ { + Index: 0, + Message: msg, + FinishReason: finishReason, + } }, + Usage: *usage, + } + out, _ := common.Marshal(full) + service.IOCopyBytesGracefully(c, resp, out) + return usage, nil +} + +func contentPtr(s string) *string { if s=="" { return nil }; return &s } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4b13a7df..a88b6850 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -12,6 +12,7 @@ import ( "one-api/constant" "one-api/dto" "one-api/logger" + "one-api/relay/channel/openrouter" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -185,10 +186,27 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo if common.DebugEnabled { println("upstream response body:", string(responseBody)) } + // Unmarshal to simpleResponse + if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() { + // 尝试解析为 openrouter enterprise + var enterpriseResponse openrouter.OpenRouterEnterpriseResponse + err = common.Unmarshal(responseBody, &enterpriseResponse) + if err != nil { + return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + if enterpriseResponse.Success { + responseBody = enterpriseResponse.Data + } else { + logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data)) + return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError) + } + } + err = common.Unmarshal(responseBody, &simpleResponse) if err != nil { return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } + if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" { return nil, types.WithOpenAIError(*oaiError, resp.StatusCode) } diff --git a/relay/channel/openrouter/dto.go b/relay/channel/openrouter/dto.go index 607f495b..a3249985 100644 --- a/relay/channel/openrouter/dto.go +++ b/relay/channel/openrouter/dto.go @@ -1,5 +1,7 @@ package openrouter +import "encoding/json" + type RequestReasoning struct { // One of the following (not both): Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style) @@ -7,3 +9,8 @@ type RequestReasoning struct { // Optional: Default is false. All models support this. Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response } + +type OpenRouterEnterpriseResponse struct { + Data json.RawMessage `json:"data"` + Success bool `json:"success"` +} diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index eb88412a..21d6e170 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/http" "net/textproto" + channelconstant "one-api/constant" "one-api/dto" "one-api/relay/channel" "one-api/relay/channel/openai" @@ -188,20 +189,26 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + // 支持自定义域名,如果未设置则使用默认域名 + baseUrl := info.ChannelBaseUrl + if baseUrl == "" { + baseUrl = channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] + } + switch info.RelayMode { case constant.RelayModeChatCompletions: if strings.HasPrefix(info.UpstreamModelName, "bot") { - return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil } - return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/chat/completions", baseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/embeddings", baseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/generations", baseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil + return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/volcengine/constants.go b/relay/channel/volcengine/constants.go index fca10e7c..87a12b27 100644 --- a/relay/channel/volcengine/constants.go +++ b/relay/channel/volcengine/constants.go @@ -9,6 +9,11 @@ var ModelList = []string{ "Doubao-lite-4k", "Doubao-embedding", "doubao-seedream-4-0-250828", + "seedream-4-0-250828", + "doubao-seedance-1-0-pro-250528", + "seedance-1-0-pro-250528", + "doubao-seed-1-6-thinking-250715", + "seed-1-6-thinking-250715", } var ChannelName = "volcengine" diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 9d5c190f..9503d5d3 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -207,10 +207,6 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap return nil, nil, err } - defer func() { - conn.Close() - }() - data := requestOpenAI2Xunfei(textRequest, appId, domain) err = conn.WriteJSON(data) if err != nil { @@ -220,6 +216,9 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap dataChan := make(chan XunfeiChatResponse) stopChan := make(chan bool) go func() { + defer func() { + conn.Close() + }() for { _, msg, err := conn.ReadMessage() if err != nil { diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 967bf88a..25ef68c6 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -164,6 +164,8 @@ const EditChannelModal = (props) => { settings: '', // 仅 Vertex: 密钥格式(存入 settings.vertex_key_type) vertex_key_type: 'json', + // 企业账户设置 + is_enterprise_account: false, }; const [batch, setBatch] = useState(false); const [multiToSingle, setMultiToSingle] = useState(false); @@ -189,6 +191,7 @@ const EditChannelModal = (props) => { const [channelSearchValue, setChannelSearchValue] = useState(''); const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式 const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加) + const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户 // 2FA验证查看密钥相关状态 const [twoFAState, setTwoFAState] = useState({ @@ -235,7 +238,7 @@ const EditChannelModal = (props) => { pass_through_body_enabled: false, system_prompt: '', }); - const showApiConfigCard = inputs.type !== 45; // 控制是否显示 API 配置卡片(仅当渠道类型不是 豆包 时显示) + const showApiConfigCard = true; // 控制是否显示 API 配置卡片 const getInitValues = () => ({ ...originInputs }); // 处理渠道额外设置的更新 @@ -342,6 +345,10 @@ const EditChannelModal = (props) => { case 36: localModels = ['suno_music', 'suno_lyrics']; break; + case 45: + localModels = getChannelModels(value); + setInputs((prevInputs) => ({ ...prevInputs, base_url: 'https://ark.cn-beijing.volces.com' })); + break; default: localModels = getChannelModels(value); break; @@ -433,15 +440,19 @@ const EditChannelModal = (props) => { parsedSettings.azure_responses_version || ''; // 读取 Vertex 密钥格式 data.vertex_key_type = parsedSettings.vertex_key_type || 'json'; + // 读取企业账户设置 + data.is_enterprise_account = parsedSettings.openrouter_enterprise === true; } catch (error) { console.error('解析其他设置失败:', error); data.azure_responses_version = ''; data.region = ''; data.vertex_key_type = 'json'; + data.is_enterprise_account = false; } } else { // 兼容历史数据:老渠道没有 settings 时,默认按 json 展示 data.vertex_key_type = 'json'; + data.is_enterprise_account = false; } setInputs(data); @@ -453,6 +464,8 @@ const EditChannelModal = (props) => { } else { setAutoBan(true); } + // 同步企业账户状态 + setIsEnterpriseAccount(data.is_enterprise_account || false); setBasicModels(getChannelModels(data.type)); // 同步更新channelSettings状态显示 setChannelSettings({ @@ -712,6 +725,8 @@ const EditChannelModal = (props) => { }); // 重置密钥模式状态 setKeyMode('append'); + // 重置企业账户状态 + setIsEnterpriseAccount(false); // 清空表单中的key_mode字段 if (formApiRef.current) { formApiRef.current.setValue('key_mode', undefined); @@ -842,6 +857,10 @@ const EditChannelModal = (props) => { showInfo(t('请至少选择一个模型!')); return; } + if (localInputs.type === 45 && (!localInputs.base_url || localInputs.base_url.trim() === '')) { + showInfo(t('请输入API地址!')); + return; + } if ( localInputs.model_mapping && localInputs.model_mapping !== '' && @@ -871,6 +890,21 @@ const EditChannelModal = (props) => { }; localInputs.setting = JSON.stringify(channelExtraSettings); + // 处理type === 20的企业账户设置 + if (localInputs.type === 20) { + let settings = {}; + if (localInputs.settings) { + try { + settings = JSON.parse(localInputs.settings); + } catch (error) { + console.error('解析settings失败:', error); + } + } + // 设置企业账户标识,无论是true还是false都要传到后端 + settings.openrouter_enterprise = localInputs.is_enterprise_account === true; + localInputs.settings = JSON.stringify(settings); + } + // 清理不需要发送到后端的字段 delete localInputs.force_format; delete localInputs.thinking_to_content; @@ -878,6 +912,7 @@ const EditChannelModal = (props) => { delete localInputs.pass_through_body_enabled; delete localInputs.system_prompt; delete localInputs.system_prompt_override; + delete localInputs.is_enterprise_account; // 顶层的 vertex_key_type 不应发送给后端 delete localInputs.vertex_key_type; @@ -1195,6 +1230,21 @@ const EditChannelModal = (props) => { onChange={(value) => handleInputChange('type', value)} /> + {inputs.type === 20 && ( + { + setIsEnterpriseAccount(value); + handleInputChange('is_enterprise_account', value); + }} + extraText={t('企业账户为特殊返回格式,需要特殊处理,如果非企业账户,请勿勾选')} + initValue={inputs.is_enterprise_account} + /> + )} + { /> )} + + {inputs.type === 45 && ( +
+ + handleInputChange('base_url', value) + } + optionList={[ + { + value: 'https://ark.cn-beijing.volces.com', + label: 'https://ark.cn-beijing.volces.com' + }, + { + value: 'https://ark.ap-southeast.bytepluses.com', + label: 'https://ark.ap-southeast.bytepluses.com' + } + ]} + defaultValue='https://ark.cn-beijing.volces.com' + /> +
+ )} )} diff --git a/web/src/hooks/channels/useChannelsData.jsx b/web/src/hooks/channels/useChannelsData.jsx index 65460a06..7d09d4df 100644 --- a/web/src/hooks/channels/useChannelsData.jsx +++ b/web/src/hooks/channels/useChannelsData.jsx @@ -25,13 +25,9 @@ import { showInfo, showSuccess, loadChannelModels, - copy, + copy } from '../../helpers'; -import { - CHANNEL_OPTIONS, - ITEMS_PER_PAGE, - MODEL_TABLE_PAGE_SIZE, -} from '../../constants'; +import { CHANNEL_OPTIONS, ITEMS_PER_PAGE, MODEL_TABLE_PAGE_SIZE } from '../../constants'; import { useIsMobile } from '../common/useIsMobile'; import { useTableCompactMode } from '../common/useTableCompactMode'; import { Modal } from '@douyinfe/semi-ui'; @@ -68,7 +64,7 @@ export const useChannelsData = () => { // Status filter const [statusFilter, setStatusFilter] = useState( - localStorage.getItem('channel-status-filter') || 'all', + localStorage.getItem('channel-status-filter') || 'all' ); // Type tabs states @@ -83,9 +79,10 @@ export const useChannelsData = () => { const [testingModels, setTestingModels] = useState(new Set()); const [selectedModelKeys, setSelectedModelKeys] = useState([]); const [isBatchTesting, setIsBatchTesting] = useState(false); - const [testQueue, setTestQueue] = useState([]); - const [isProcessingQueue, setIsProcessingQueue] = useState(false); const [modelTablePage, setModelTablePage] = useState(1); + + // 使用 ref 来避免闭包问题,类似旧版实现 + const shouldStopBatchTestingRef = useRef(false); // Multi-key management states const [showMultiKeyManageModal, setShowMultiKeyManageModal] = useState(false); @@ -119,12 +116,9 @@ export const useChannelsData = () => { // Initialize from localStorage useEffect(() => { const localIdSort = localStorage.getItem('id-sort') === 'true'; - const localPageSize = - parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; - const localEnableTagMode = - localStorage.getItem('enable-tag-mode') === 'true'; - const localEnableBatchDelete = - localStorage.getItem('enable-batch-delete') === 'true'; + const localPageSize = parseInt(localStorage.getItem('page-size')) || ITEMS_PER_PAGE; + const localEnableTagMode = localStorage.getItem('enable-tag-mode') === 'true'; + const localEnableBatchDelete = localStorage.getItem('enable-batch-delete') === 'true'; setIdSort(localIdSort); setPageSize(localPageSize); @@ -182,10 +176,7 @@ export const useChannelsData = () => { // Save column preferences useEffect(() => { if (Object.keys(visibleColumns).length > 0) { - localStorage.setItem( - 'channels-table-columns', - JSON.stringify(visibleColumns), - ); + localStorage.setItem('channels-table-columns', JSON.stringify(visibleColumns)); } }, [visibleColumns]); @@ -299,21 +290,14 @@ export const useChannelsData = () => { const { searchKeyword, searchGroup, searchModel } = getFormValues(); if (searchKeyword !== '' || searchGroup !== '' || searchModel !== '') { setLoading(true); - await searchChannels( - enableTagMode, - typeKey, - statusF, - page, - pageSize, - idSort, - ); + await searchChannels(enableTagMode, typeKey, statusF, page, pageSize, idSort); setLoading(false); return; } const reqId = ++requestCounter.current; setLoading(true); - const typeParam = typeKey !== 'all' ? `&type=${typeKey}` : ''; + const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : ''; const statusParam = statusF !== 'all' ? `&status=${statusF}` : ''; const res = await API.get( `/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}${statusParam}`, @@ -327,10 +311,7 @@ export const useChannelsData = () => { if (success) { const { items, total, type_counts } = data; if (type_counts) { - const sumAll = Object.values(type_counts).reduce( - (acc, v) => acc + v, - 0, - ); + const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0); setTypeCounts({ ...type_counts, all: sumAll }); } setChannelFormat(items, enableTagMode); @@ -354,18 +335,11 @@ export const useChannelsData = () => { setSearching(true); try { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { - await loadChannels( - page, - pageSz, - sortFlag, - enableTagMode, - typeKey, - statusF, - ); + await loadChannels(page, pageSz, sortFlag, enableTagMode, typeKey, statusF); return; } - const typeParam = typeKey !== 'all' ? `&type=${typeKey}` : ''; + const typeParam = (typeKey !== 'all') ? `&type=${typeKey}` : ''; const statusParam = statusF !== 'all' ? `&status=${statusF}` : ''; const res = await API.get( `/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${sortFlag}&tag_mode=${enableTagMode}&p=${page}&page_size=${pageSz}${typeParam}${statusParam}`, @@ -373,10 +347,7 @@ export const useChannelsData = () => { const { success, message, data } = res.data; if (success) { const { items = [], total = 0, type_counts = {} } = data; - const sumAll = Object.values(type_counts).reduce( - (acc, v) => acc + v, - 0, - ); + const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0); setTypeCounts({ ...type_counts, all: sumAll }); setChannelFormat(items, enableTagMode); setChannelCount(total); @@ -395,14 +366,7 @@ export const useChannelsData = () => { if (searchKeyword === '' && searchGroup === '' && searchModel === '') { await loadChannels(page, pageSize, idSort, enableTagMode); } else { - await searchChannels( - enableTagMode, - activeTypeKey, - statusFilter, - page, - pageSize, - idSort, - ); + await searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort); } }; @@ -488,16 +452,9 @@ export const useChannelsData = () => { const { searchKeyword, searchGroup, searchModel } = getFormValues(); setActivePage(page); if (searchKeyword === '' && searchGroup === '' && searchModel === '') { - loadChannels(page, pageSize, idSort, enableTagMode).then(() => {}); + loadChannels(page, pageSize, idSort, enableTagMode).then(() => { }); } else { - searchChannels( - enableTagMode, - activeTypeKey, - statusFilter, - page, - pageSize, - idSort, - ); + searchChannels(enableTagMode, activeTypeKey, statusFilter, page, pageSize, idSort); } }; @@ -513,14 +470,7 @@ export const useChannelsData = () => { showError(reason); }); } else { - searchChannels( - enableTagMode, - activeTypeKey, - statusFilter, - 1, - size, - idSort, - ); + searchChannels(enableTagMode, activeTypeKey, statusFilter, 1, size, idSort); } }; @@ -551,10 +501,7 @@ export const useChannelsData = () => { showError(res?.data?.message || t('渠道复制失败')); } } catch (error) { - showError( - t('渠道复制失败: ') + - (error?.response?.data?.message || error?.message || error), - ); + showError(t('渠道复制失败: ') + (error?.response?.data?.message || error?.message || error)); } }; @@ -593,11 +540,7 @@ export const useChannelsData = () => { data.priority = parseInt(data.priority); break; case 'weight': - if ( - data.weight === undefined || - data.weight < 0 || - data.weight === '' - ) { + if (data.weight === undefined || data.weight < 0 || data.weight === '') { showInfo('权重必须是非负整数!'); return; } @@ -740,136 +683,226 @@ export const useChannelsData = () => { const res = await API.post(`/api/channel/fix`); const { success, message, data } = res.data; if (success) { - showSuccess( - t('已修复 ${success} 个通道,失败 ${fails} 个通道。') - .replace('${success}', data.success) - .replace('${fails}', data.fails), - ); + showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails)); await refresh(); } else { showError(message); } }; - // Test channel + // Test channel - 单个模型测试,参考旧版实现 const testChannel = async (record, model) => { - setTestQueue((prev) => [...prev, { channel: record, model }]); - if (!isProcessingQueue) { - setIsProcessingQueue(true); + const testKey = `${record.id}-${model}`; + + // 检查是否应该停止批量测试 + if (shouldStopBatchTestingRef.current && isBatchTesting) { + return Promise.resolve(); } - }; - // Process test queue - const processTestQueue = async () => { - if (!isProcessingQueue || testQueue.length === 0) return; - - const { channel, model, indexInFiltered } = testQueue[0]; - - if (currentTestChannel && currentTestChannel.id === channel.id) { - let pageNo; - if (indexInFiltered !== undefined) { - pageNo = Math.floor(indexInFiltered / MODEL_TABLE_PAGE_SIZE) + 1; - } else { - const filteredModelsList = currentTestChannel.models - .split(',') - .filter((m) => - m.toLowerCase().includes(modelSearchKeyword.toLowerCase()), - ); - const modelIdx = filteredModelsList.indexOf(model); - pageNo = - modelIdx !== -1 - ? Math.floor(modelIdx / MODEL_TABLE_PAGE_SIZE) + 1 - : 1; - } - setModelTablePage(pageNo); - } + // 添加到正在测试的模型集合 + setTestingModels(prev => new Set([...prev, model])); try { - setTestingModels((prev) => new Set([...prev, model])); - const res = await API.get( - `/api/channel/test/${channel.id}?model=${model}`, - ); + const res = await API.get(`/api/channel/test/${record.id}?model=${model}`); + + // 检查是否在请求期间被停止 + if (shouldStopBatchTestingRef.current && isBatchTesting) { + return Promise.resolve(); + } + const { success, message, time } = res.data; - setModelTestResults((prev) => ({ + // 更新测试结果 + setModelTestResults(prev => ({ ...prev, - [`${channel.id}-${model}`]: { success, time }, + [testKey]: { + success, + message, + time: time || 0, + timestamp: Date.now() + } })); if (success) { - updateChannelProperty(channel.id, (ch) => { - ch.response_time = time * 1000; - ch.test_time = Date.now() / 1000; + // 更新渠道响应时间 + updateChannelProperty(record.id, (channel) => { + channel.response_time = time * 1000; + channel.test_time = Date.now() / 1000; }); - if (!model) { + + if (!model || model === '') { showInfo( t('通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。') - .replace('${name}', channel.name) + .replace('${name}', record.name) + .replace('${time.toFixed(2)}', time.toFixed(2)), + ); + } else { + showInfo( + t('通道 ${name} 测试成功,模型 ${model} 耗时 ${time.toFixed(2)} 秒。') + .replace('${name}', record.name) + .replace('${model}', model) .replace('${time.toFixed(2)}', time.toFixed(2)), ); } } else { - showError(message); + showError(`${t('模型')} ${model}: ${message}`); } } catch (error) { - showError(error.message); + // 处理网络错误 + const testKey = `${record.id}-${model}`; + setModelTestResults(prev => ({ + ...prev, + [testKey]: { + success: false, + message: error.message || t('网络错误'), + time: 0, + timestamp: Date.now() + } + })); + showError(`${t('模型')} ${model}: ${error.message || t('测试失败')}`); } finally { - setTestingModels((prev) => { + // 从正在测试的模型集合中移除 + setTestingModels(prev => { const newSet = new Set(prev); newSet.delete(model); return newSet; }); } - - setTestQueue((prev) => prev.slice(1)); }; - // Monitor queue changes - useEffect(() => { - if (testQueue.length > 0 && isProcessingQueue) { - processTestQueue(); - } else if (testQueue.length === 0 && isProcessingQueue) { - setIsProcessingQueue(false); - setIsBatchTesting(false); - } - }, [testQueue, isProcessingQueue]); - - // Batch test models + // 批量测试单个渠道的所有模型,参考旧版实现 const batchTestModels = async () => { - if (!currentTestChannel) return; + if (!currentTestChannel || !currentTestChannel.models) { + showError(t('渠道模型信息不完整')); + return; + } + + const models = currentTestChannel.models.split(',').filter(model => + model.toLowerCase().includes(modelSearchKeyword.toLowerCase()) + ); + + if (models.length === 0) { + showError(t('没有找到匹配的模型')); + return; + } setIsBatchTesting(true); - setModelTablePage(1); + shouldStopBatchTestingRef.current = false; // 重置停止标志 - const filteredModels = currentTestChannel.models - .split(',') - .filter((model) => - model.toLowerCase().includes(modelSearchKeyword.toLowerCase()), - ); + // 清空该渠道之前的测试结果 + setModelTestResults(prev => { + const newResults = { ...prev }; + models.forEach(model => { + const testKey = `${currentTestChannel.id}-${model}`; + delete newResults[testKey]; + }); + return newResults; + }); - setTestQueue( - filteredModels.map((model, idx) => ({ - channel: currentTestChannel, - model, - indexInFiltered: idx, - })), - ); - setIsProcessingQueue(true); + try { + showInfo(t('开始批量测试 ${count} 个模型,已清空上次结果...').replace('${count}', models.length)); + + // 提高并发数量以加快测试速度,参考旧版的并发限制 + const concurrencyLimit = 5; + const results = []; + + for (let i = 0; i < models.length; i += concurrencyLimit) { + // 检查是否应该停止 + if (shouldStopBatchTestingRef.current) { + showInfo(t('批量测试已停止')); + break; + } + + const batch = models.slice(i, i + concurrencyLimit); + showInfo(t('正在测试第 ${current} - ${end} 个模型 (共 ${total} 个)') + .replace('${current}', i + 1) + .replace('${end}', Math.min(i + concurrencyLimit, models.length)) + .replace('${total}', models.length) + ); + + const batchPromises = batch.map(model => testChannel(currentTestChannel, model)); + const batchResults = await Promise.allSettled(batchPromises); + results.push(...batchResults); + + // 再次检查是否应该停止 + if (shouldStopBatchTestingRef.current) { + showInfo(t('批量测试已停止')); + break; + } + + // 短暂延迟避免过于频繁的请求 + if (i + concurrencyLimit < models.length) { + await new Promise(resolve => setTimeout(resolve, 100)); + } + } + + if (!shouldStopBatchTestingRef.current) { + // 等待一小段时间确保所有结果都已更新 + await new Promise(resolve => setTimeout(resolve, 300)); + + // 使用当前状态重新计算结果统计 + setModelTestResults(currentResults => { + let successCount = 0; + let failCount = 0; + + models.forEach(model => { + const testKey = `${currentTestChannel.id}-${model}`; + const result = currentResults[testKey]; + if (result && result.success) { + successCount++; + } else { + failCount++; + } + }); + + // 显示完成消息 + setTimeout(() => { + showSuccess(t('批量测试完成!成功: ${success}, 失败: ${fail}, 总计: ${total}') + .replace('${success}', successCount) + .replace('${fail}', failCount) + .replace('${total}', models.length) + ); + }, 100); + + return currentResults; // 不修改状态,只是为了获取最新值 + }); + } + } catch (error) { + showError(t('批量测试过程中发生错误: ') + error.message); + } finally { + setIsBatchTesting(false); + } + }; + + // 停止批量测试 + const stopBatchTesting = () => { + shouldStopBatchTestingRef.current = true; + setIsBatchTesting(false); + setTestingModels(new Set()); + showInfo(t('已停止批量测试')); + }; + + // 清空测试结果 + const clearTestResults = () => { + setModelTestResults({}); + showInfo(t('已清空测试结果')); }; // Handle close modal const handleCloseModal = () => { + // 如果正在批量测试,先停止测试 if (isBatchTesting) { - setTestQueue([]); - setIsProcessingQueue(false); - setIsBatchTesting(false); - showSuccess(t('已停止测试')); - } else { - setShowModelTestModal(false); - setModelSearchKeyword(''); - setSelectedModelKeys([]); - setModelTablePage(1); + shouldStopBatchTestingRef.current = true; + showInfo(t('关闭弹窗,已停止批量测试')); } + + setShowModelTestModal(false); + setModelSearchKeyword(''); + setIsBatchTesting(false); + setTestingModels(new Set()); + setSelectedModelKeys([]); + setModelTablePage(1); + // 可选择性保留测试结果,这里不清空以便用户查看 }; // Type counts @@ -1012,4 +1045,4 @@ export const useChannelsData = () => { setCompactMode, setActivePage, }; -}; +}; \ No newline at end of file