diff --git a/constant/context_key.go b/constant/context_key.go index 32dd9617..26ff1738 100644 --- a/constant/context_key.go +++ b/constant/context_key.go @@ -40,4 +40,6 @@ const ( ContextKeyUserGroup ContextKey = "user_group" ContextKeyUsingGroup ContextKey = "group" ContextKeyUserName ContextKey = "username" + + ContextKeySystemPromptOverride ContextKey = "system_prompt_override" ) diff --git a/controller/midjourney.go b/controller/midjourney.go index 02ad708f..30a5a09a 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -145,6 +145,22 @@ func UpdateMidjourneyTaskBulk() { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) } + // 映射 VideoUrl + task.VideoUrl = responseItem.VideoUrl + + // 映射 VideoUrls - 将数组序列化为 JSON 字符串 + if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 { + videoUrlsStr, err := json.Marshal(responseItem.VideoUrls) + if err != nil { + common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) + task.VideoUrls = "[]" // 失败时设置为空数组 + } else { + task.VideoUrls = string(videoUrlsStr) + } + } else { + task.VideoUrls = "" // 空值时清空字段 + } + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) @@ -208,6 +224,20 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto) if oldTask.Progress != "100%" && newTask.FailReason != "" { return true } + // 检查 VideoUrl 是否需要更新 + if oldTask.VideoUrl != newTask.VideoUrl { + return true + } + // 检查 VideoUrls 是否需要更新 + if newTask.VideoUrls != nil && len(newTask.VideoUrls) > 0 { + newVideoUrlsStr, _ := json.Marshal(newTask.VideoUrls) + if oldTask.VideoUrls != string(newVideoUrlsStr) { + return true + } + } else if oldTask.VideoUrls != "" { + // 如果新数据没有 VideoUrls 但旧数据有,需要更新(清空) + return true + } return false } diff --git a/dto/channel_settings.go b/dto/channel_settings.go index 47f8bf95..1c697048 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -6,4 +6,5 @@ type ChannelSettings struct { Proxy string `json:"proxy"` PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"` SystemPrompt string `json:"system_prompt,omitempty"` + SystemPromptOverride bool `json:"system_prompt_override,omitempty"` } diff --git a/dto/gemini.go b/dto/gemini.go index f7acd355..1bd1fe4c 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -216,10 +216,14 @@ type GeminiEmbeddingRequest struct { OutputDimensionality int `json:"outputDimensionality,omitempty"` } -type GeminiEmbeddingResponse struct { - Embedding ContentEmbedding `json:"embedding"` +type GeminiBatchEmbeddingRequest struct { + Requests []*GeminiEmbeddingRequest `json:"requests"` } -type ContentEmbedding struct { +type GeminiEmbedding struct { Values []float64 `json:"values"` } + +type GeminiBatchEmbeddingResponse struct { + Embeddings []*GeminiEmbedding `json:"embeddings"` +} diff --git a/dto/openai_request.go b/dto/openai_request.go index fcd47d07..f33b2c1e 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -78,6 +78,8 @@ func (r *GeneralOpenAIRequest) GetSystemRoleName() string { if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") { return "developer" } + } else if strings.HasPrefix(r.Model, "gpt-5") { + return "developer" } return "system" } diff --git a/middleware/distributor.go b/middleware/distributor.go index e8abcbe9..dea30abf 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -267,6 +267,8 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode common.SetContextKey(c, constant.ContextKeyChannelKey, key) common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL()) + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false) + // TODO: api_version统一 switch channel.Type { case constant.ChannelTypeAzure: diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 01dfea2c..e5b4146a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -114,7 +114,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || strings.HasPrefix(info.UpstreamModelName, "embedding") || strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") { - return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil + return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil } action := "generateContent" @@ -159,29 +159,35 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela if len(inputs) == 0 { return nil, errors.New("input is empty") } - - // only process the first input - geminiRequest := dto.GeminiEmbeddingRequest{ - Content: dto.GeminiChatContent{ - Parts: []dto.GeminiPart{ - { - Text: inputs[0], + // process all inputs + geminiRequests := make([]map[string]interface{}, 0, len(inputs)) + for _, input := range inputs { + geminiRequest := map[string]interface{}{ + "model": fmt.Sprintf("models/%s", info.UpstreamModelName), + "content": dto.GeminiChatContent{ + Parts: []dto.GeminiPart{ + { + Text: input, + }, }, }, - }, - } - - // set specific parameters for different models - // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent - switch info.UpstreamModelName { - case "text-embedding-004": - // except embedding-001 supports setting `OutputDimensionality` - if request.Dimensions > 0 { - geminiRequest.OutputDimensionality = request.Dimensions } + + // set specific parameters for different models + // https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent + switch info.UpstreamModelName { + case "text-embedding-004","gemini-embedding-exp-03-07","gemini-embedding-001": + // Only newer models introduced after 2024 support OutputDimensionality + if request.Dimensions > 0 { + geminiRequest["outputDimensionality"] = request.Dimensions + } + } + geminiRequests = append(geminiRequests, geminiRequest) } - return geminiRequest, nil + return map[string]interface{}{ + "requests": geminiRequests, + }, nil } func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 25a2c412..24b42942 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -1071,7 +1071,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } - var geminiResponse dto.GeminiEmbeddingResponse + var geminiResponse dto.GeminiBatchEmbeddingResponse if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) } @@ -1079,14 +1079,16 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h // convert to openai format response openAIResponse := dto.OpenAIEmbeddingResponse{ Object: "list", - Data: []dto.OpenAIEmbeddingResponseItem{ - { - Object: "embedding", - Embedding: geminiResponse.Embedding.Values, - Index: 0, - }, - }, - Model: info.UpstreamModelName, + Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(geminiResponse.Embeddings)), + Model: info.UpstreamModelName, + } + + for i, embedding := range geminiResponse.Embeddings { + openAIResponse.Data = append(openAIResponse.Data, dto.OpenAIEmbeddingResponseItem{ + Object: "embedding", + Embedding: embedding.Values, + Index: i, + }) } // calculate usage diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 83070fe5..a83e30e6 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -54,8 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - token := getZhipuToken(info.ApiKey) - req.Set("Authorization", token) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 98a852f5..cb8adfe4 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -1,69 +1,10 @@ package zhipu_4v import ( - "github.com/golang-jwt/jwt" - "one-api/common" "one-api/dto" "strings" - "sync" - "time" ) -// https://open.bigmodel.cn/doc/api#chatglm_std -// chatglm_std, chatglm_lite -// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke -// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke - -var zhipuTokens sync.Map -var expSeconds int64 = 24 * 3600 - -func getZhipuToken(apikey string) string { - data, ok := zhipuTokens.Load(apikey) - if ok { - tokenData := data.(tokenData) - if time.Now().Before(tokenData.ExpiryTime) { - return tokenData.Token - } - } - - split := strings.Split(apikey, ".") - if len(split) != 2 { - common.SysError("invalid zhipu key: " + apikey) - return "" - } - - id := split[0] - secret := split[1] - - expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 - expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) - - timestamp := time.Now().UnixNano() / 1e6 - - payload := jwt.MapClaims{ - "api_key": id, - "exp": expMillis, - "timestamp": timestamp, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) - - token.Header["alg"] = "HS256" - token.Header["sign_type"] = "SIGN" - - tokenString, err := token.SignedString([]byte(secret)) - if err != nil { - return "" - } - - zhipuTokens.Store(apikey, tokenData{ - Token: tokenString, - ExpiryTime: expiryTime, - }) - - return tokenString -} - func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { diff --git a/relay/relay-text.go b/relay/relay-text.go index f175dbfb..1e014615 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -201,6 +201,26 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) { Content: relayInfo.ChannelSetting.SystemPrompt, } request.Messages = append([]dto.Message{systemMessage}, request.Messages...) + } else if relayInfo.ChannelSetting.SystemPromptOverride { + common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) + // 如果有系统提示,且允许覆盖,则拼接到前面 + for i, message := range request.Messages { + if message.Role == request.GetSystemRoleName() { + if message.IsStringContent() { + request.Messages[i].SetStringContent(relayInfo.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) + } else { + contents := message.ParseContent() + contents = append([]dto.MediaContent{ + { + Type: dto.ContentTypeText, + Text: relayInfo.ChannelSetting.SystemPrompt, + }, + }, contents...) + request.Messages[i].Content = contents + } + break + } + } } } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 020a2ba9..0dae9a03 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -28,6 +28,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m other["is_model_mapped"] = true other["upstream_model_name"] = relayInfo.UpstreamModelName } + + isSystemPromptOverwritten := common.GetContextKeyBool(ctx, constant.ContextKeySystemPromptOverride) + if isSystemPromptOverwritten { + other["is_system_prompt_overwritten"] = true + } + adminInfo := make(map[string]interface{}) adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") isMultiKey := common.GetContextKeyBool(ctx, constant.ContextKeyChannelIsMultiKey) diff --git a/web/src/components/table/channels/modals/EditChannelModal.jsx b/web/src/components/table/channels/modals/EditChannelModal.jsx index 40aedcbf..b86aade5 100644 --- a/web/src/components/table/channels/modals/EditChannelModal.jsx +++ b/web/src/components/table/channels/modals/EditChannelModal.jsx @@ -131,6 +131,7 @@ const EditChannelModal = (props) => { proxy: '', pass_through_body_enabled: false, system_prompt: '', + system_prompt_override: false, }; const [batch, setBatch] = useState(false); const [multiToSingle, setMultiToSingle] = useState(false); @@ -340,12 +341,15 @@ const EditChannelModal = (props) => { data.proxy = parsedSettings.proxy || ''; data.pass_through_body_enabled = parsedSettings.pass_through_body_enabled || false; data.system_prompt = parsedSettings.system_prompt || ''; + data.system_prompt_override = parsedSettings.system_prompt_override || false; } catch (error) { console.error('解析渠道设置失败:', error); data.force_format = false; data.thinking_to_content = false; data.proxy = ''; data.pass_through_body_enabled = false; + data.system_prompt = ''; + data.system_prompt_override = false; } } else { data.force_format = false; @@ -353,6 +357,7 @@ const EditChannelModal = (props) => { data.proxy = ''; data.pass_through_body_enabled = false; data.system_prompt = ''; + data.system_prompt_override = false; } setInputs(data); @@ -372,6 +377,7 @@ const EditChannelModal = (props) => { proxy: data.proxy, pass_through_body_enabled: data.pass_through_body_enabled, system_prompt: data.system_prompt, + system_prompt_override: data.system_prompt_override || false, }); // console.log(data); } else { @@ -573,6 +579,7 @@ const EditChannelModal = (props) => { proxy: '', pass_through_body_enabled: false, system_prompt: '', + system_prompt_override: false, }); // 重置密钥模式状态 setKeyMode('append'); @@ -721,6 +728,7 @@ const EditChannelModal = (props) => { proxy: localInputs.proxy || '', pass_through_body_enabled: localInputs.pass_through_body_enabled || false, system_prompt: localInputs.system_prompt || '', + system_prompt_override: localInputs.system_prompt_override || false, }; localInputs.setting = JSON.stringify(channelExtraSettings); @@ -730,6 +738,7 @@ const EditChannelModal = (props) => { delete localInputs.proxy; delete localInputs.pass_through_body_enabled; delete localInputs.system_prompt; + delete localInputs.system_prompt_override; let res; localInputs.auto_ban = localInputs.auto_ban ? 1 : 0; @@ -1722,6 +1731,14 @@ const EditChannelModal = (props) => { showClear extraText={t('用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置')} /> +
{modalContent}
+ {isVideo ? ( + + ) : ( +{modalContent}
+ )}