diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index fcfcb0c3..ef5d1935 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -5,7 +5,7 @@ import ( _ "embed" "fmt" "github.com/go-redis/redis/v8" - "one-api/logger" + "one-api/common" "sync" ) @@ -27,7 +27,7 @@ func New(ctx context.Context, r *redis.Client) *RedisLimiter { // 预加载脚本 limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() if err != nil { - logger.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) + common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) } instance = &RedisLimiter{ client: r, diff --git a/controller/channel-billing.go b/controller/channel-billing.go index bbf0f97a..5152e060 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -8,7 +8,6 @@ import ( "net/http" "one-api/common" "one-api/constant" - "one-api/logger" "one-api/model" "one-api/service" "one-api/setting" @@ -486,8 +485,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - logger.SysLog("updating all channels") + common.SysLog("updating all channels") _ = updateAllChannelsBalance() - logger.SysLog("channels update done") + common.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index ec2e6226..32486a8b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -13,7 +13,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/middleware" "one-api/model" "one-api/relay" @@ -133,8 +132,17 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: newAPIError, } } + request := buildTestRequest(testModel) - info := relaycommon.GenRelayInfo(c) + info, err := relaycommon.GenRelayInfo(c, types.RelayFormatOpenAI, request, nil) + + if err != nil { + return testResult{ + context: c, + localErr: err, + newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed), + } + } err = helper.ModelMappedHelper(c, info, nil) if err != nil { @@ -144,7 +152,9 @@ func testChannel(channel *model.Channel, testModel string) testResult { newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), } } + testModel = info.UpstreamModelName + request.Model = testModel apiType, _ := common.ChannelType2APIType(channel.Type) adaptor := relay.GetAdaptor(apiType) @@ -156,13 +166,12 @@ func testChannel(channel *model.Channel, testModel string) testResult { } } - request := buildTestRequest(testModel) // 创建一个用于日志的 info 副本,移除 ApiKey logInfo := *info logInfo.ApiKey = "" - logger.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) + priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta()) if err != nil { return testResult{ context: c, @@ -280,7 +289,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Group: info.UsingGroup, Other: other, }) - logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) + common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return testResult{ context: c, localErr: nil, @@ -462,13 +471,13 @@ func TestAllChannels(c *gin.Context) { func AutomaticallyTestChannels(frequency int) { if frequency <= 0 { - logger.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") + common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") return } for { time.Sleep(time.Duration(frequency) * time.Minute) - logger.SysLog("testing all channels") + common.SysLog("testing all channels") _ = testAllChannels(false) - logger.SysLog("channel test finished") + common.SysLog("channel test finished") } } diff --git a/controller/console_migrate.go b/controller/console_migrate.go index d21f5e21..f0812c3d 100644 --- a/controller/console_migrate.go +++ b/controller/console_migrate.go @@ -4,10 +4,11 @@ package controller import ( "encoding/json" - "github.com/gin-gonic/gin" "net/http" - "one-api/logger" + "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* @@ -98,6 +99,6 @@ func MigrateConsoleSetting(c *gin.Context) { // 重新加载 OptionMap model.InitOptionMap() - logger.SysLog("console setting migrated") + common.SysLog("console setting migrated") c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) } diff --git a/controller/github.go b/controller/github.go index 0715a8fe..881d6dc1 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/logger" "one-api/model" "strconv" "time" @@ -48,7 +47,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -64,7 +63,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() diff --git a/controller/model.go b/controller/model.go index d03fdeb2..398503e8 100644 --- a/controller/model.go +++ b/controller/model.go @@ -93,7 +93,9 @@ func init() { if !success || apiType == constant.APITypeAIProxyLibrary { continue } - meta := &relaycommon.RelayInfo{ChannelType: i} + meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{ + ChannelType: i, + }} adaptor := relay.GetAdaptor(apiType) adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() diff --git a/controller/oidc.go b/controller/oidc.go index 1e3435a8..f3def0e3 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" "one-api/common" - "one-api/logger" "one-api/model" "one-api/setting" "one-api/setting/system_setting" @@ -59,7 +58,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } res, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res.Body.Close() @@ -70,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { } if oidcResponse.AccessToken == "" { - logger.SysError("OIDC 获取 Token 失败,请检查设置!") + common.SysLog("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") } @@ -81,12 +80,12 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) res2, err := client.Do(req) if err != nil { - logger.SysLog(err.Error()) + common.SysLog(err.Error()) return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") } defer res2.Body.Close() if res2.StatusCode != http.StatusOK { - logger.SysError("OIDC 获取用户信息失败!请检查设置!") + common.SysLog("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") } @@ -96,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { return nil, err } if oidcUser.OpenID == "" || oidcUser.Email == "" { - logger.SysError("OIDC 获取用户信息为空!请检查设置!") + common.SysLog("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") } return &oidcUser, nil diff --git a/controller/playground.go b/controller/playground.go index dd930802..8a1cb2b6 100644 --- a/controller/playground.go +++ b/controller/playground.go @@ -56,5 +56,5 @@ func Playground(c *gin.Context) { //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) - Relay(c) + Relay(c, types.RelayFormatOpenAI) } diff --git a/controller/relay.go b/controller/relay.go index 583ac036..8b67fd89 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -104,26 +104,6 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { return } - //includeUsage := true - //// 判断用户是否需要返回使用情况 - //if textRequest.StreamOptions != nil { - // includeUsage = textRequest.StreamOptions.IncludeUsage - //} - // - //// 如果不支持StreamOptions,将StreamOptions设置为nil - //if !relayInfo.SupportStreamOptions || !textRequest.Stream { - // textRequest.StreamOptions = nil - //} else { - // // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions - // if constant.ForceStreamOption { - // textRequest.StreamOptions = &dto.StreamOptions{ - // IncludeUsage: true, - // } - // } - //} - // - //relayInfo.ShouldIncludeUsage = includeUsage - relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws) if err != nil { newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed) @@ -178,7 +158,7 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { switch relayFormat { case types.RelayFormatOpenAIRealtime: - newAPIError = relay.WssHelper(c, ws) + newAPIError = relay.WssHelper(c, relayInfo) case types.RelayFormatClaude: newAPIError = relay.ClaudeHelper(c, relayInfo) case types.RelayFormatGemini: @@ -324,35 +304,45 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t } func RelayMidjourney(c *gin.Context) { - relayMode := c.GetInt("relay_mode") - var err *dto.MidjourneyResponse - switch relayMode { + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil) + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "description": fmt.Sprintf("failed to generate relay info: %s", err.Error()), + "type": "upstream_error", + "code": 4, + }) + return + } + + var mjErr *dto.MidjourneyResponse + switch relayInfo.RelayMode { case relayconstant.RelayModeMidjourneyNotify: - err = relay.RelayMidjourneyNotify(c) + mjErr = relay.RelayMidjourneyNotify(c) case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: - err = relay.RelayMidjourneyTask(c, relayMode) + mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode) case relayconstant.RelayModeMidjourneyTaskImageSeed: - err = relay.RelayMidjourneyTaskImageSeed(c) + mjErr = relay.RelayMidjourneyTaskImageSeed(c) case relayconstant.RelayModeSwapFace: - err = relay.RelaySwapFace(c) + mjErr = relay.RelaySwapFace(c, relayInfo) default: - err = relay.RelayMidjourneySubmit(c, relayMode) + mjErr = relay.RelayMidjourneySubmit(c, relayInfo) } //err = relayMidjourneySubmit(c, relayMode) - log.Println(err) - if err != nil { + log.Println(mjErr) + if mjErr != nil { statusCode := http.StatusBadRequest - if err.Code == 30 { - err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" + if mjErr.Code == 30 { + mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" statusCode = http.StatusTooManyRequests } c.JSON(statusCode, gin.H{ - "description": fmt.Sprintf("%s %s", err.Description, err.Result), + "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result), "type": "upstream_error", - "code": err.Code, + "code": mjErr.Code, }) channelId := c.GetInt("channel_id") - logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) + logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result))) } } diff --git a/controller/task.go b/controller/task.go index a5b28ae2..1082d7a1 100644 --- a/controller/task.go +++ b/controller/task.go @@ -26,7 +26,7 @@ func UpdateTaskBulk() { //imageModel := "midjourney" for { time.Sleep(time.Duration(15) * time.Second) - logger.SysLog("任务进度轮询开始") + common.SysLog("任务进度轮询开始") ctx := context.TODO() allTasks := model.GetAllUnFinishSyncTasks(500) platformTask := make(map[constant.TaskPlatform][]*model.Task) @@ -66,7 +66,7 @@ func UpdateTaskBulk() { UpdateTaskByPlatform(platform, taskChannelM, taskM) } - logger.SysLog("任务进度轮询完成") + common.SysLog("任务进度轮询完成") } } @@ -78,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][ _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) default: if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil { - logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) + common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err)) } } } @@ -100,14 +100,14 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas } channel, err := model.CacheGetChannel(channelId) if err != nil { - logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) err = model.TaskBulkUpdate(taskIds, map[string]any{ "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), "status": "FAILURE", "progress": "100%", }) if err != nil { - logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) } return err } @@ -119,7 +119,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas "ids": taskIds, }) if err != nil { - logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) return err } if resp.StatusCode != http.StatusOK { @@ -129,7 +129,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas defer resp.Body.Close() responseBody, err := io.ReadAll(resp.Body) if err != nil { - logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err)) return err } var responseItems dto.TaskResponse[[]dto.SunoDataResponse] @@ -139,7 +139,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas return err } if !responseItems.IsSuccess() { - logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) return err } @@ -179,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas err = task.Update() if err != nil { - logger.SysError("UpdateMidjourneyTask task error: " + err.Error()) + common.SysLog("UpdateMidjourneyTask task error: " + err.Error()) } } return nil diff --git a/controller/task_video.go b/controller/task_video.go index dca42955..ffb6728b 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "one-api/common" "one-api/constant" "one-api/dto" "one-api/logger" @@ -37,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha "progress": "100%", }) if errUpdate != nil { - logger.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) + common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) } return fmt.Errorf("CacheGetChannel failed: %w", err) } @@ -112,7 +113,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.StartTime = now } case model.TaskStatusSuccess: - task.Progress = "100%" + task.Progress = "100%" if task.FinishTime == 0 { task.FinishTime = now } @@ -140,7 +141,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha task.Progress = taskResult.Progress } if err := task.Update(); err != nil { - logger.SysError("UpdateVideoTask task error: " + err.Error()) + common.SysLog("UpdateVideoTask task error: " + err.Error()) } return nil diff --git a/controller/token.go b/controller/token.go index db575fec..399ccb4f 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,7 +3,6 @@ package controller import ( "net/http" "one-api/common" - "one-api/logger" "one-api/model" "strconv" @@ -103,7 +102,7 @@ func AddToken(c *gin.Context) { "success": false, "message": "生成令牌失败", }) - logger.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } cleanToken := model.Token{ diff --git a/controller/twofa.go b/controller/twofa.go index 0ab66029..1859a128 100644 --- a/controller/twofa.go +++ b/controller/twofa.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/logger" "one-api/model" "strconv" @@ -71,7 +70,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成2FA密钥失败", }) - logger.SysError("生成TOTP密钥失败: " + err.Error()) + common.SysLog("生成TOTP密钥失败: " + err.Error()) return } @@ -82,7 +81,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - logger.SysError("生成备用码失败: " + err.Error()) + common.SysLog("生成备用码失败: " + err.Error()) return } @@ -116,7 +115,7 @@ func Setup2FA(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - logger.SysError("保存备用码失败: " + err.Error()) + common.SysLog("保存备用码失败: " + err.Error()) return } @@ -295,7 +294,7 @@ func Get2FAStatus(c *gin.Context) { // 获取剩余备用码数量 backupCount, err := model.GetUnusedBackupCodeCount(userId) if err != nil { - logger.SysError("获取备用码数量失败: " + err.Error()) + common.SysLog("获取备用码数量失败: " + err.Error()) } else { status["backup_codes_remaining"] = backupCount } @@ -369,7 +368,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "生成备用码失败", }) - logger.SysError("生成备用码失败: " + err.Error()) + common.SysLog("生成备用码失败: " + err.Error()) return } @@ -379,7 +378,7 @@ func RegenerateBackupCodes(c *gin.Context) { "success": false, "message": "保存备用码失败", }) - logger.SysError("保存备用码失败: " + err.Error()) + common.SysLog("保存备用码失败: " + err.Error()) return } diff --git a/controller/user.go b/controller/user.go index 8ce44fa6..a7d59f17 100644 --- a/controller/user.go +++ b/controller/user.go @@ -193,7 +193,7 @@ func Register(c *gin.Context) { "success": false, "message": "数据库错误,请稍后重试", }) - logger.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) + common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) return } if exist { @@ -236,7 +236,7 @@ func Register(c *gin.Context) { "success": false, "message": "生成默认令牌失败", }) - logger.SysError("failed to generate token key: " + err.Error()) + common.SysLog("failed to generate token key: " + err.Error()) return } // 生成默认令牌 @@ -343,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) { "success": false, "message": "生成失败", }) - logger.SysError("failed to generate key: " + err.Error()) + common.SysLog("failed to generate key: " + err.Error()) return } user.SetAccessToken(key) diff --git a/dto/openai_request.go b/dto/openai_request.go index 0c01c503..12aa54f4 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -332,9 +332,9 @@ func (m *MediaContent) GetVideoUrl() *MessageVideoUrl { } type MessageImageUrl struct { - Url string `json:"url"` - Detail string `json:"detail"` - //MimeType string + Url string `json:"url"` + Detail string `json:"detail"` + MimeType string } func (m *MessageImageUrl) IsRemoteImage() bool { diff --git a/dto/request_common.go b/dto/request_common.go index e5dde8b5..8bd25785 100644 --- a/dto/request_common.go +++ b/dto/request_common.go @@ -9,3 +9,16 @@ type Request interface { GetTokenCountMeta() *types.TokenCountMeta IsStream(c *gin.Context) bool } + +type BaseRequest struct { +} + +func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta { + return &types.TokenCountMeta{ + TokenType: types.TokenTypeTokenizer, + } +} + +func (b *BaseRequest) IsStream(c *gin.Context) bool { + return false +} diff --git a/main.go b/main.go index 9a5bd652..2dfddacc 100644 --- a/main.go +++ b/main.go @@ -36,22 +36,22 @@ func main() { err := InitResources() if err != nil { - logger.FatalLog("failed to initialize resources: " + err.Error()) + common.FatalLog("failed to initialize resources: " + err.Error()) return } - logger.SysLog("New API " + common.Version + " started") + common.SysLog("New API " + common.Version + " started") if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } if common.DebugEnabled { - logger.SysLog("running in debug mode") + common.SysLog("running in debug mode") } defer func() { err := model.CloseDB() if err != nil { - logger.FatalLog("failed to close database: " + err.Error()) + common.FatalLog("failed to close database: " + err.Error()) } }() @@ -60,18 +60,18 @@ func main() { common.MemoryCacheEnabled = true } if common.MemoryCacheEnabled { - logger.SysLog("memory cache enabled") - logger.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + common.SysLog("memory cache enabled") + common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) // Add panic recovery and retry for InitChannelCache func() { defer func() { if r := recover(); r != nil { - logger.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) + common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) // Retry once _, _, fixErr := model.FixAbility() if fixErr != nil { - logger.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) + common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error())) } } }() @@ -90,14 +90,14 @@ func main() { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } @@ -111,7 +111,7 @@ func main() { } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true - logger.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") model.InitBatchUpdater() } @@ -120,13 +120,13 @@ func main() { log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) }) go common.Monitor() - logger.SysLog("pprof enabled") + common.SysLog("pprof enabled") } // Initialize HTTP server server := gin.New() server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), @@ -156,7 +156,7 @@ func main() { } err = server.Run(":" + port) if err != nil { - logger.FatalLog("failed to start HTTP server: " + err.Error()) + common.FatalLog("failed to start HTTP server: " + err.Error()) } } @@ -165,8 +165,8 @@ func InitResources() error { // This is a placeholder function for future resource initialization err := godotenv.Load(".env") if err != nil { - logger.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") - logger.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") + common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量") + common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.") } // 加载环境变量 @@ -184,7 +184,7 @@ func InitResources() error { // Initialize SQL Database err = model.InitDB() if err != nil { - logger.FatalLog("failed to initialize database: " + err.Error()) + common.FatalLog("failed to initialize database: " + err.Error()) return err } diff --git a/middleware/recover.go b/middleware/recover.go index 6c9c7ef6..d78c8137 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "net/http" - "one-api/logger" + "one-api/common" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - logger.SysError(fmt.Sprintf("panic detected: %v", err)) - logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + common.SysLog(fmt.Sprintf("panic detected: %v", err)) + common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index a136a900..106a7278 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -7,7 +7,6 @@ import ( "net/http" "net/url" "one-api/common" - "one-api/logger" ) type turnstileCheckResponse struct { @@ -38,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc { "remoteip": {c.ClientIP()}, }) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -50,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/model/ability.go b/model/ability.go index ac5530d8..123fc7be 100644 --- a/model/ability.go +++ b/model/ability.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "one-api/common" - "one-api/logger" "strings" "sync" @@ -295,13 +294,13 @@ func FixAbility() (int, int, error) { if common.UsingSQLite { err := DB.Exec("DELETE FROM abilities").Error if err != nil { - logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) return 0, 0, err } } else { err := DB.Exec("TRUNCATE TABLE abilities").Error if err != nil { - logger.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) return 0, 0, err } } @@ -321,7 +320,7 @@ func FixAbility() (int, int, error) { // Delete all abilities of this channel err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error if err != nil { - logger.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) + common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error())) failCount += len(chunk) continue } @@ -329,7 +328,7 @@ func FixAbility() (int, int, error) { for _, channel := range chunk { err = channel.AddAbilities(nil) if err != nil { - logger.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) + common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) failCount++ } else { successCount++ diff --git a/model/channel.go b/model/channel.go index c0d253fc..af769f63 100644 --- a/model/channel.go +++ b/model/channel.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/types" "strings" "sync" @@ -210,7 +209,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { if channel.OtherInfo != "" { err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) if err != nil { - logger.SysError("failed to unmarshal other info: " + err.Error()) + common.SysLog("failed to unmarshal other info: " + err.Error()) } } return otherInfo @@ -219,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { otherInfoBytes, err := json.Marshal(otherInfo) if err != nil { - logger.SysError("failed to marshal other info: " + err.Error()) + common.SysLog("failed to marshal other info: " + err.Error()) return } channel.OtherInfo = string(otherInfoBytes) @@ -489,7 +488,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { ResponseTime: int(responseTime), }).Error if err != nil { - logger.SysError("failed to update response time: " + err.Error()) + common.SysLog("failed to update response time: " + err.Error()) } } @@ -499,7 +498,7 @@ func (channel *Channel) UpdateBalance(balance float64) { Balance: balance, }).Error if err != nil { - logger.SysError("failed to update balance: " + err.Error()) + common.SysLog("failed to update balance: " + err.Error()) } } @@ -615,7 +614,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if shouldUpdateAbilities { err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) if err != nil { - logger.SysError("failed to update ability status: " + err.Error()) + common.SysLog("failed to update ability status: " + err.Error()) } } }() @@ -643,7 +642,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri } err = channel.Save() if err != nil { - logger.SysError("failed to update channel status: " + err.Error()) + common.SysLog("failed to update channel status: " + err.Error()) return false } } @@ -705,7 +704,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models * for _, channel := range channels { err = channel.UpdateAbilities(nil) if err != nil { - logger.SysError("failed to update abilities: " + err.Error()) + common.SysLog("failed to update abilities: " + err.Error()) } } } @@ -729,7 +728,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - logger.SysError("failed to update channel used quota: " + err.Error()) + common.SysLog("failed to update channel used quota: " + err.Error()) } } @@ -822,7 +821,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { if channel.Setting != nil && *channel.Setting != "" { err := common.Unmarshal([]byte(*channel.Setting), &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) channel.Setting = nil // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -833,7 +832,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings { func (channel *Channel) SetSetting(setting dto.ChannelSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - logger.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } channel.Setting = common.GetPointer[string](string(settingBytes)) @@ -844,7 +843,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { if channel.OtherSettings != "" { err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) channel.OtherSettings = "{}" // 清空设置以避免后续错误 _ = channel.Save() // 保存修改 } @@ -855,7 +854,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { settingBytes, err := common.Marshal(setting) if err != nil { - logger.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } channel.OtherSettings = string(settingBytes) @@ -866,7 +865,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} { if channel.ParamOverride != nil && *channel.ParamOverride != "" { err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) if err != nil { - logger.SysError("failed to unmarshal param override: " + err.Error()) + common.SysLog("failed to unmarshal param override: " + err.Error()) } } return paramOverride diff --git a/model/channel_cache.go b/model/channel_cache.go index 22216027..86866e40 100644 --- a/model/channel_cache.go +++ b/model/channel_cache.go @@ -6,7 +6,6 @@ import ( "math/rand" "one-api/common" "one-api/constant" - "one-api/logger" "one-api/setting" "one-api/setting/ratio_setting" "sort" @@ -85,13 +84,13 @@ func InitChannelCache() { } channelsIDM = newChannelId2channel channelSyncLock.Unlock() - logger.SysLog("channels synced from database") + common.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - logger.SysLog("syncing channels from database") + common.SysLog("syncing channels from database") InitChannelCache() } } diff --git a/model/log.go b/model/log.go index d9495968..e443516d 100644 --- a/model/log.go +++ b/model/log.go @@ -88,7 +88,7 @@ func RecordLog(userId int, logType int, content string) { } err := LOG_DB.Create(log).Error if err != nil { - logger.SysError("failed to record log: " + err.Error()) + common.SysLog("failed to record log: " + err.Error()) } } diff --git a/model/main.go b/model/main.go index 1e582e1a..dbf27152 100644 --- a/model/main.go +++ b/model/main.go @@ -5,7 +5,6 @@ import ( "log" "one-api/common" "one-api/constant" - "one-api/logger" "os" "strings" "sync" @@ -85,7 +84,7 @@ func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") + common.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return err @@ -109,7 +108,7 @@ func CheckSetup() { if setup == nil { // No setup record exists, check if we have a root user if RootUserExists() { - logger.SysLog("system is not initialized, but root user exists") + common.SysLog("system is not initialized, but root user exists") // Create setup record newSetup := Setup{ Version: common.Version, @@ -117,16 +116,16 @@ func CheckSetup() { } err := DB.Create(&newSetup).Error if err != nil { - logger.SysLog("failed to create setup record: " + err.Error()) + common.SysLog("failed to create setup record: " + err.Error()) } constant.Setup = true } else { - logger.SysLog("system is not initialized and no root user exists") + common.SysLog("system is not initialized and no root user exists") constant.Setup = false } } else { // Setup record exists, system is initialized - logger.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) + common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String()) constant.Setup = true } } @@ -139,7 +138,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { if dsn != "" { if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { // Use PostgreSQL - logger.SysLog("using PostgreSQL as database") + common.SysLog("using PostgreSQL as database") if !isLog { common.UsingPostgreSQL = true } else { @@ -153,7 +152,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } if strings.HasPrefix(dsn, "local") { - logger.SysLog("SQL_DSN not set, using SQLite as database") + common.SysLog("SQL_DSN not set, using SQLite as database") if !isLog { common.UsingSQLite = true } else { @@ -164,7 +163,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use MySQL - logger.SysLog("using MySQL as database") + common.SysLog("using MySQL as database") // check parseTime if !strings.Contains(dsn, "parseTime") { if strings.Contains(dsn, "?") { @@ -183,7 +182,7 @@ func chooseDB(envName string, isLog bool) (*gorm.DB, error) { }) } // Use SQLite - logger.SysLog("SQL_DSN not set, using SQLite as database") + common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ PrepareStmt: true, // precompile SQL @@ -217,11 +216,11 @@ func InitDB() (err error) { if common.UsingMySQL { //_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded } - logger.SysLog("database migration started") + common.SysLog("database migration started") err = migrateDB() return err } else { - logger.FatalLog(err) + common.FatalLog(err) } return err } @@ -254,11 +253,11 @@ func InitLogDB() (err error) { if !common.IsMasterNode { return nil } - logger.SysLog("database migration started") + common.SysLog("database migration started") err = migrateLOGDB() return err } else { - logger.FatalLog(err) + common.FatalLog(err) } return err } @@ -355,7 +354,7 @@ func migrateDBFast() error { return err } } - logger.SysLog("database migrated") + common.SysLog("database migrated") return nil } @@ -504,6 +503,6 @@ func PingDB() error { } lastPingTime = time.Now() - logger.SysLog("Database pinged successfully") + common.SysLog("Database pinged successfully") return nil } diff --git a/model/option.go b/model/option.go index 8fcd13a8..2121710c 100644 --- a/model/option.go +++ b/model/option.go @@ -2,7 +2,6 @@ package model import ( "one-api/common" - "one-api/logger" "one-api/setting" "one-api/setting/config" "one-api/setting/operation_setting" @@ -151,7 +150,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - logger.SysError("failed to update option map: " + err.Error()) + common.SysLog("failed to update option map: " + err.Error()) } } } @@ -159,7 +158,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - logger.SysLog("syncing options from database") + common.SysLog("syncing options from database") loadOptionsFromDatabase() } } diff --git a/model/pricing.go b/model/pricing.go index 31aa5cdf..3c9349de 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -3,7 +3,6 @@ package model import ( "encoding/json" "fmt" - "one-api/logger" "strings" "one-api/common" @@ -93,7 +92,7 @@ func updatePricing() { //modelRatios := common.GetModelRatios() enableAbilities, err := GetAllEnableAbilityWithChannels() if err != nil { - logger.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) + common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) return } // 预加载模型元数据与供应商一次,避免循环查询 diff --git a/model/token.go b/model/token.go index 63c17e2d..320b5cf0 100644 --- a/model/token.go +++ b/model/token.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "one-api/common" - "one-api/logger" "strings" "github.com/bytedance/gopkg/util/gopool" @@ -92,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - logger.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } return token, errors.New("该令牌已过期") @@ -103,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - logger.SysError("failed to update token status" + err.Error()) + common.SysLog("failed to update token status" + err.Error()) } } keyPrefix := key[:3] @@ -135,7 +134,7 @@ func GetTokenById(id int) (*Token, error) { if shouldUpdateRedis(true, err) { gopool.Go(func() { if err := cacheSetToken(token); err != nil { - logger.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -148,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) { if shouldUpdateRedis(fromDB, err) && token != nil { gopool.Go(func() { if err := cacheSetToken(*token); err != nil { - logger.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } @@ -179,7 +178,7 @@ func (token *Token) Update() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - logger.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -195,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) { gopool.Go(func() { err := cacheSetToken(*token) if err != nil { - logger.SysError("failed to update token cache: " + err.Error()) + common.SysLog("failed to update token cache: " + err.Error()) } }) } @@ -210,7 +209,7 @@ func (token *Token) Delete() (err error) { gopool.Go(func() { err := cacheDeleteToken(token.Key) if err != nil { - logger.SysError("failed to delete token cache: " + err.Error()) + common.SysLog("failed to delete token cache: " + err.Error()) } }) } @@ -270,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheIncrTokenQuota(key, int64(quota)) if err != nil { - logger.SysError("failed to increase token quota: " + err.Error()) + common.SysLog("failed to increase token quota: " + err.Error()) } }) } @@ -300,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) { gopool.Go(func() { err := cacheDecrTokenQuota(key, int64(quota)) if err != nil { - logger.SysError("failed to decrease token quota: " + err.Error()) + common.SysLog("failed to decrease token quota: " + err.Error()) } }) } diff --git a/model/twofa.go b/model/twofa.go index b2ea54e0..8e97289f 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "one-api/common" - "one-api/logger" "time" "gorm.io/gorm" @@ -244,7 +243,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { if !common.ValidateTOTPCode(t.Secret, code) { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - logger.SysError("更新2FA失败次数失败: " + err.Error()) + common.SysLog("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -256,7 +255,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - logger.SysError("更新2FA使用记录失败: " + err.Error()) + common.SysLog("更新2FA使用记录失败: " + err.Error()) } return true, nil @@ -278,7 +277,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { if !valid { // 增加失败次数 if err := t.IncrementFailedAttempts(); err != nil { - logger.SysError("更新2FA失败次数失败: " + err.Error()) + common.SysLog("更新2FA失败次数失败: " + err.Error()) } return false, nil } @@ -290,7 +289,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) { t.LastUsedAt = &now if err := t.Update(); err != nil { - logger.SysError("更新2FA使用记录失败: " + err.Error()) + common.SysLog("更新2FA使用记录失败: " + err.Error()) } return true, nil diff --git a/model/usedata.go b/model/usedata.go index f0027a8d..1255b0be 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -4,7 +4,6 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" - "one-api/logger" "sync" "time" ) @@ -25,12 +24,12 @@ func UpdateQuotaData() { // recover defer func() { if r := recover(); r != nil { - logger.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) + common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) } }() for { if common.DataExportEnabled { - logger.SysLog("正在更新数据看板数据...") + common.SysLog("正在更新数据看板数据...") SaveQuotaDataCache() } time.Sleep(time.Duration(common.DataExportInterval) * time.Minute) @@ -92,7 +91,7 @@ func SaveQuotaDataCache() { } } CacheQuotaData = make(map[string]*QuotaData) - logger.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) + common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size)) } func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64, tokenUsed int) { @@ -103,7 +102,7 @@ func increaseQuotaData(userId int, username string, modelName string, count int, "token_used": gorm.Expr("token_used + ?", tokenUsed), }).Error if err != nil { - logger.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) + common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err)) } } diff --git a/model/user.go b/model/user.go index 244380ad..29d7a446 100644 --- a/model/user.go +++ b/model/user.go @@ -76,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting { if user.Setting != "" { err := json.Unmarshal([]byte(user.Setting), &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -85,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting { func (user *User) SetSetting(setting dto.UserSetting) { settingBytes, err := json.Marshal(setting) if err != nil { - logger.SysError("failed to marshal setting: " + err.Error()) + common.SysLog("failed to marshal setting: " + err.Error()) return } user.Setting = string(settingBytes) @@ -518,7 +518,7 @@ func IsAdmin(userId int) bool { var user User err := DB.Where("id = ?", userId).Select("role").Find(&user).Error if err != nil { - logger.SysError("no such user " + err.Error()) + common.SysLog("no such user " + err.Error()) return false } return user.Role >= common.RoleAdminUser @@ -573,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserQuotaCache(id, quota); err != nil { - logger.SysError("failed to update user quota cache: " + err.Error()) + common.SysLog("failed to update user quota cache: " + err.Error()) } }) } @@ -611,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserGroupCache(id, group); err != nil { - logger.SysError("failed to update user group cache: " + err.Error()) + common.SysLog("failed to update user group cache: " + err.Error()) } }) } @@ -640,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserSettingCache(id, setting); err != nil { - logger.SysError("failed to update user setting cache: " + err.Error()) + common.SysLog("failed to update user setting cache: " + err.Error()) } }) } @@ -670,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) { gopool.Go(func() { err := cacheIncrUserQuota(id, int64(quota)) if err != nil { - logger.SysError("failed to increase user quota: " + err.Error()) + common.SysLog("failed to increase user quota: " + err.Error()) } }) if !db && common.BatchUpdateEnabled { @@ -695,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { gopool.Go(func() { err := cacheDecrUserQuota(id, int64(quota)) if err != nil { - logger.SysError("failed to decrease user quota: " + err.Error()) + common.SysLog("failed to decrease user quota: " + err.Error()) } }) if common.BatchUpdateEnabled { @@ -751,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { }, ).Error if err != nil { - logger.SysError("failed to update user used quota and request count: " + err.Error()) + common.SysLog("failed to update user used quota and request count: " + err.Error()) return } @@ -768,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) { }, ).Error if err != nil { - logger.SysError("failed to update user used quota: " + err.Error()) + common.SysLog("failed to update user used quota: " + err.Error()) } } func updateUserRequestCount(id int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error if err != nil { - logger.SysError("failed to update user request count: " + err.Error()) + common.SysLog("failed to update user request count: " + err.Error()) } } @@ -786,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) { if shouldUpdateRedis(fromDB, err) { gopool.Go(func() { if err := updateUserNameCache(id, username); err != nil { - logger.SysError("failed to update user name cache: " + err.Error()) + common.SysLog("failed to update user name cache: " + err.Error()) } }) } diff --git a/model/user_cache.go b/model/user_cache.go index dec7597b..936e1a43 100644 --- a/model/user_cache.go +++ b/model/user_cache.go @@ -5,7 +5,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "time" "github.com/gin-gonic/gin" @@ -38,7 +37,7 @@ func (user *UserBase) GetSetting() dto.UserSetting { if user.Setting != "" { err := common.Unmarshal([]byte(user.Setting), &setting) if err != nil { - logger.SysError("failed to unmarshal setting: " + err.Error()) + common.SysLog("failed to unmarshal setting: " + err.Error()) } } return setting @@ -79,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) { if shouldUpdateRedis(fromDB, err) && user != nil { gopool.Go(func() { if err := updateUserCache(*user); err != nil { - logger.SysError("failed to update user status cache: " + err.Error()) + common.SysLog("failed to update user status cache: " + err.Error()) } }) } diff --git a/model/utils.go b/model/utils.go index abd96b79..dced2bc6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -3,7 +3,6 @@ package model import ( "errors" "one-api/common" - "one-api/logger" "sync" "time" @@ -66,7 +65,7 @@ func batchUpdate() { return } - logger.SysLog("batch update started") + common.SysLog("batch update started") for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] @@ -78,12 +77,12 @@ func batchUpdate() { case BatchUpdateTypeUserQuota: err := increaseUserQuota(key, value) if err != nil { - logger.SysError("failed to batch update user quota: " + err.Error()) + common.SysLog("failed to batch update user quota: " + err.Error()) } case BatchUpdateTypeTokenQuota: err := increaseTokenQuota(key, value) if err != nil { - logger.SysError("failed to batch update token quota: " + err.Error()) + common.SysLog("failed to batch update token quota: " + err.Error()) } case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) @@ -94,7 +93,7 @@ func batchUpdate() { } } } - logger.SysLog("batch update finished") + common.SysLog("batch update finished") } func RecordExist(err error) (bool, error) { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index bfb94008..0ae8a8d1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -34,20 +34,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string switch info.RelayFormat { - case relaycommon.RelayFormatClaude: - fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl) + case types.RelayFormatClaude: + fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl) default: switch info.RelayMode { case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl) case constant.RelayModeRerank: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) case constant.RelayModeCompletions: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl) default: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl) } } @@ -118,7 +118,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: if info.IsStream { err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 841896cf..645882bc 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "one-api/common" "one-api/dto" "one-api/logger" relaycommon "one-api/relay/common" @@ -22,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { imageRequest.Input.Prompt = request.Prompt imageRequest.Model = request.Model imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) - imageRequest.Parameters.N = request.N + imageRequest.Parameters.N = int(request.N) imageRequest.ResponseFormat = request.ResponseFormat return &imageRequest } func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { - url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) + url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID) var aliResponse AliResponse @@ -43,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error client := &http.Client{} resp, err := client.Do(req) if err != nil { - logger.SysError("updateTask client.Do err: " + err.Error()) + common.SysLog("updateTask client.Do err: " + err.Error()) return &aliResponse, err, nil } defer resp.Body.Close() @@ -53,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error var response AliResponse err = json.Unmarshal(responseBody, &response) if err != nil { - logger.SysError("updateTask NewDecoder err: " + err.Error()) + common.SysLog("updateTask NewDecoder err: " + err.Error()) return &aliResponse, err, nil } diff --git a/relay/channel/ali/text.go b/relay/channel/ali/text.go index 17fcef2a..67b63286 100644 --- a/relay/channel/ali/text.go +++ b/relay/channel/ali/text.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/relay/helper" "one-api/service" "strings" @@ -150,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if aliResponse.Usage.OutputTokens != 0 { @@ -163,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, lastResponseText = aliResponse.Output.Text jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 8396a844..32e301ee 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -101,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { default: suffix += strings.ToLower(info.UpstreamModelName) } - fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix) var accessToken string var err error if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 696c2496..31e8319e 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -119,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var baiduResponse BaiduChatStreamResponse err := common.Unmarshal([]byte(data), &baiduResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } if baiduResponse.Usage.TotalTokens != 0 { @@ -130,7 +129,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseBaidu2OpenAI(&baiduResponse) err = helper.ObjectData(c, response) if err != nil { - logger.SysError("error sending stream response: " + err.Error()) + common.SysLog("error sending stream response: " + err.Error()) } return true }) diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index ba59e307..6744f8ba 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -45,15 +45,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 39b8ce2f..41583d30 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -53,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if a.RequestMode == RequestModeMessage { - return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 5d839908..57670bcf 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -376,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla for _, toolCall := range message.ParseToolCalls() { inputObj := make(map[string]any) if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { - logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) continue } claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ @@ -610,13 +610,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud var claudeResponse dto.ClaudeResponse err := common.UnmarshalJsonStr(data, &claudeResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return types.NewError(err, types.ErrorCodeBadResponseBody) } if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { return types.WithClaudeError(*claudeError, http.StatusInternalServerError) } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) if requestMode == RequestModeCompletion { @@ -629,7 +629,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } } helper.ClaudeChunkData(c, claudeResponse, data) - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { + } else if info.RelayFormat == types.RelayFormatOpenAI { response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { @@ -654,21 +654,20 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau } if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if common.DebugEnabled { - logger.SysError("claude response usage is not complete, maybe upstream error") + common.SysLog("claude response usage is not complete, maybe upstream error") } claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) } } - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { // - } else if info.RelayFormat == relaycommon.RelayFormatOpenAI { - + } else if info.RelayFormat == types.RelayFormatOpenAI { if info.ShouldIncludeUsage { response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) err := helper.ObjectData(c, response) if err != nil { - logger.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } } helper.Done(c) @@ -722,14 +721,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud } var responseData []byte switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) openaiResponse.Usage = *claudeInfo.Usage responseData, err = json.Marshal(openaiResponse) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: responseData = data } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 4b9f5028..bdea72f0 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -36,13 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil case constant.RelayModeResponses: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil default: - return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil } } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 887f9efd..c8a38d46 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -43,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else { - return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index ccef9b23..af357348 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -119,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse @@ -154,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http } jsonStr, err := json.Marshal(openaiResp) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) diff --git a/relay/channel/coze/adaptor.go b/relay/channel/coze/adaptor.go index 658c6193..0f2a6fd3 100644 --- a/relay/channel/coze/adaptor.go +++ b/relay/channel/coze/adaptor.go @@ -122,7 +122,7 @@ func (a *Adaptor) GetModelList() []string { // GetRequestURL implements channel.Adaptor. func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil + return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil } // Init implements channel.Adaptor. diff --git a/relay/channel/coze/relay-coze.go b/relay/channel/coze/relay-coze.go index 18ed46af..c480045f 100644 --- a/relay/channel/coze/relay-coze.go +++ b/relay/channel/coze/relay-coze.go @@ -9,7 +9,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -155,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var chatData CozeChatResponseData err := json.Unmarshal([]byte(data), &chatData) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -172,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var messageData CozeChatV3MessageDetail err := json.Unmarshal([]byte(data), &messageData) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } var content string err = json.Unmarshal(messageData.Content, &content) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } @@ -204,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st var errorData CozeError err := json.Unmarshal([]byte(data), &errorData) if err != nil { - logger.SysError("error_unmarshalling_stream_response: " + err.Error()) + common.SysLog("error_unmarshalling_stream_response: " + err.Error()) return } - logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) + common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) } } func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { - requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") // 将 conversationId和chatId作为参数发送get请求 @@ -259,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo } func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { - requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) + requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl) requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") req, err := http.NewRequest("GET", requestURL, nil) diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index be8de0c8..17d732ab 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fimBaseUrl := info.BaseUrl - if !strings.HasSuffix(info.BaseUrl, "/beta") { + fimBaseUrl := info.ChannelBaseUrl + if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") { fimBaseUrl += "/beta" } switch info.RelayMode { case constant.RelayModeCompletions: return fmt.Sprintf("%s/completions", fimBaseUrl), nil default: - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 8c7898c9..0a08d035 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -61,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch a.BotType { case BotTypeWorkFlow: - return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil case BotTypeCompletion: - return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil case BotTypeAgent: fallthrough default: - return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil } } diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index f03d61a4..2336fd4c 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -23,7 +22,7 @@ import ( ) func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile { - uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl) + uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl) switch media.Type { case dto.ContentTypeImageURL: // Decode base64 data @@ -37,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Decode base64 string decodedData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - logger.SysError("failed to decode base64: " + err.Error()) + common.SysLog("failed to decode base64: " + err.Error()) return nil } // Create temporary file tempFile, err := os.CreateTemp("", "dify-upload-*") if err != nil { - logger.SysError("failed to create temp file: " + err.Error()) + common.SysLog("failed to create temp file: " + err.Error()) return nil } defer tempFile.Close() @@ -52,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Write decoded data to temp file if _, err := tempFile.Write(decodedData); err != nil { - logger.SysError("failed to write to temp file: " + err.Error()) + common.SysLog("failed to write to temp file: " + err.Error()) return nil } @@ -62,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Add user field if err := writer.WriteField("user", user); err != nil { - logger.SysError("failed to add user field: " + err.Error()) + common.SysLog("failed to add user field: " + err.Error()) return nil } @@ -75,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create form file part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) if err != nil { - logger.SysError("failed to create form file: " + err.Error()) + common.SysLog("failed to create form file: " + err.Error()) return nil } // Copy file content to form if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { - logger.SysError("failed to copy file content: " + err.Error()) + common.SysLog("failed to copy file content: " + err.Error()) return nil } writer.Close() @@ -89,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me // Create HTTP request req, err := http.NewRequest("POST", uploadUrl, body) if err != nil { - logger.SysError("failed to create request: " + err.Error()) + common.SysLog("failed to create request: " + err.Error()) return nil } @@ -100,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me client := service.GetHttpClient() resp, err := client.Do(req) if err != nil { - logger.SysError("failed to send request: " + err.Error()) + common.SysLog("failed to send request: " + err.Error()) return nil } defer resp.Body.Close() @@ -110,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me Id string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - logger.SysError("failed to decode response: " + err.Error()) + common.SysLog("failed to decode response: " + err.Error()) return nil } @@ -220,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } var openaiResponse dto.ChatCompletionsStreamResponse @@ -240,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R } err = helper.ObjectData(c, openaiResponse) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 05d974f6..99b6645e 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -108,7 +108,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) if strings.HasPrefix(info.UpstreamModelName, "imagen") { - return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil + return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil } if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || @@ -118,7 +118,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsGeminiBatchEmbedding { action = "batchEmbedContents" } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } action := "generateContent" @@ -128,7 +128,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { info.DisablePing = true } } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 82a2d8de..af5e8233 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -994,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp * response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := handleFinalStream(c, info, response) if err != nil { - logger.SysError("send final response failed: " + err.Error()) + common.SysLog("send final response failed: " + err.Error()) } //if info.RelayFormat == relaycommon.RelayFormatOpenAI { // helper.Done(c) @@ -1042,19 +1042,19 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R fullTextResponse.Usage = usage switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: responseBody, err = common.Marshal(fullTextResponse) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: break } diff --git a/relay/channel/jimeng/adaptor.go b/relay/channel/jimeng/adaptor.go index ff9ac678..885a1427 100644 --- a/relay/channel/jimeng/adaptor.go +++ b/relay/channel/jimeng/adaptor.go @@ -32,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil + return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index bf318aa7..a383728f 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -45,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } return "", errors.New("invalid relay mode") } diff --git a/relay/channel/minimax/relay-minimax.go b/relay/channel/minimax/relay-minimax.go index d0a15b0d..ff9b72ea 100644 --- a/relay/channel/minimax/relay-minimax.go +++ b/relay/channel/minimax/relay-minimax.go @@ -6,5 +6,5 @@ import ( ) func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil } diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 45cb3290..f98ff869 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/mokaai/adaptor.go b/relay/channel/mokaai/adaptor.go index 37db2aec..f9da685f 100644 --- a/relay/channel/mokaai/adaptor.go +++ b/relay/channel/mokaai/adaptor.go @@ -54,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if strings.HasPrefix(info.UpstreamModelName, "m3e") { suffix = "embeddings" } - fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) + fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix) return fullRequestURL, nil } diff --git a/relay/channel/moonshot/adaptor.go b/relay/channel/moonshot/adaptor.go index d540388d..29004d0c 100644 --- a/relay/channel/moonshot/adaptor.go +++ b/relay/channel/moonshot/adaptor.go @@ -44,19 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayFormat { - case relaycommon.RelayFormatClaude: - return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil + case types.RelayFormatClaude: + return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil default: if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } } @@ -89,10 +89,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: adaptor := openai.Adaptor{} return adaptor.DoResponse(c, resp, info) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: if info.IsStream { err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 1f3fda8d..1a0caf75 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -48,14 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - if info.RelayFormat == relaycommon.RelayFormatClaude { - return info.BaseUrl + "/v1/chat/completions", nil + if info.RelayFormat == types.RelayFormatClaude { + return info.ChannelBaseUrl + "/v1/chat/completions", nil } switch info.RelayMode { case relayconstant.RelayModeEmbeddings: - return info.BaseUrl + "/api/embed", nil + return info.ChannelBaseUrl + "/api/embed", nil default: - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index fc1749a0..d783b9d8 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -105,14 +105,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { - if strings.HasPrefix(info.BaseUrl, "https://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") + if strings.HasPrefix(info.ChannelBaseUrl, "https://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://") baseUrl = "wss://" + baseUrl - info.BaseUrl = baseUrl - } else if strings.HasPrefix(info.BaseUrl, "http://") { - baseUrl := strings.TrimPrefix(info.BaseUrl, "http://") + info.ChannelBaseUrl = baseUrl + } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") { + baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://") baseUrl = "ws://" + baseUrl - info.BaseUrl = baseUrl + info.ChannelBaseUrl = baseUrl } } switch info.ChannelType { @@ -126,7 +126,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) task := strings.TrimPrefix(requestURL, "/v1/") - if info.RelayFormat == relaycommon.RelayFormatClaude { + if info.RelayFormat == types.RelayFormatClaude { task = strings.TrimPrefix(task, "messages") task = "chat/completions" + task } @@ -136,7 +136,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { responsesApiVersion := "preview" subUrl := "/openai/v1/responses" - if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") { + if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") { subUrl = "/openai/responses" responsesApiVersion = apiVersion } @@ -146,7 +146,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion) - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil } model_ := info.UpstreamModelName @@ -159,18 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == relayconstant.RelayModeRealtime { requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) } - return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil case constant.ChannelTypeMiniMax: return minimax.GetRequestURL(info) case constant.ChannelTypeCustom: - url := info.BaseUrl + url := info.ChannelBaseUrl url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) return url, nil default: - if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini { + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } } diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go index 80973aa1..2a4b4938 100644 --- a/relay/channel/openai/helper.go +++ b/relay/channel/openai/helper.go @@ -12,6 +12,7 @@ import ( relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/types" "strings" "github.com/gin-gonic/gin" @@ -22,11 +23,11 @@ func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string info.SendResponseCount++ switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: return sendStreamData(c, info, data, forceFormat, thinkToContent) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: return handleClaudeFormat(c, data, info) - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: return handleGeminiFormat(c, data, info) } return nil @@ -111,14 +112,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex var streamResponses []dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.ChatCompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { return err } if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { - logger.SysError("error processing stream response: " + err.Error()) + common.SysLog("error processing stream response: " + err.Error()) } } return nil @@ -147,7 +148,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui var streamResponses []dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { // 一次性解析失败,逐个解析 - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) for _, item := range streamItems { var streamResponse dto.CompletionsStreamResponse if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { @@ -202,7 +203,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream usage *dto.Usage, containStreamUsage bool) { switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if info.ShouldIncludeUsage && !containStreamUsage { response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) @@ -210,11 +211,11 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream } helper.Done(c) - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: info.ClaudeConvertInfo.Done = true var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -225,10 +226,10 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream _ = helper.ClaudeData(c, *resp) } - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: var streamResponse dto.ChatCompletionsStreamResponse if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return } @@ -246,7 +247,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream geminiResponseStr, err := common.Marshal(geminiResponse) if err != nil { - logger.SysError("error marshalling gemini response: " + err.Error()) + common.SysLog("error marshalling gemini response: " + err.Error()) return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 447e0f31..00dde46d 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -130,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re if lastStreamData != "" { err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) if err != nil { - logger.SysError("error handling stream format: " + err.Error()) + common.SysLog("error handling stream format: " + err.Error()) } } if len(data) > 0 { @@ -147,7 +147,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) } - if info.RelayFormat == relaycommon.RelayFormatOpenAI { + if info.RelayFormat == types.RelayFormatOpenAI { if shouldSendLastResp { _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) } @@ -211,7 +211,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } switch info.RelayFormat { - case relaycommon.RelayFormatOpenAI: + case types.RelayFormatOpenAI: if forceFormat { responseBody, err = common.Marshal(simpleResponse) if err != nil { @@ -220,14 +220,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo } else { break } - case relaycommon.RelayFormatClaude: + case types.RelayFormatClaude: claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeRespStr, err := common.Marshal(claudeResp) if err != nil { return nil, types.NewError(err, types.ErrorCodeBadResponseBody) } responseBody = claudeRespStr - case relaycommon.RelayFormatGemini: + case types.RelayFormatGemini: geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info) geminiRespStr, err := common.Marshal(geminiResp) if err != nil { diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 4d1ab783..2a022a1b 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil + return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 1264b2b4..3a6ec2f4 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -7,7 +7,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -59,7 +58,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, go func() { responseBody, err := io.ReadAll(resp.Body) if err != nil { - logger.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) stopChan <- true return } @@ -67,7 +66,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, var palmResponse PaLMChatResponse err = json.Unmarshal(responseBody, &palmResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) stopChan <- true return } @@ -79,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, } jsonResponse, err := json.Marshal(fullTextResponse) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) stopChan <- true return } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 92cb08a2..8ab9c854 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 05e6d453..4c176c08 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.RelayMode == constant.RelayModeRerank { - return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeEmbeddings { - return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { - return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 8d057513..a5ada137 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -76,7 +76,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl // apiKey format: "access_key|secret_key" keyParts := strings.Split(info.ApiKey, "|") diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index b7b9a5ff..1fecda08 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -81,7 +81,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey // apiKey format: "access_key|secret_key" diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 1deb33fd..df2bb99e 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/service" @@ -60,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { - baseURL := info.BaseUrl + baseURL := info.ChannelBaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) return fullRequestURL, nil } @@ -140,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) if err != nil { - logger.SysError(fmt.Sprintf("Get Task error: %v", err)) + common.SysLog(fmt.Sprintf("Get Task error: %v", err)) return nil, err } defer req.Body.Close() diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index f40b480c..b0cc0bdc 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -86,7 +86,7 @@ type TaskAdaptor struct { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - a.baseURL = info.BaseUrl + a.baseURL = info.ChannelBaseUrl } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index b86d8a16..ab96ecaa 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/", info.BaseUrl), nil + return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/tencent/relay-tencent.go b/relay/channel/tencent/relay-tencent.go index d3aeab3f..f33a275c 100644 --- a/relay/channel/tencent/relay-tencent.go +++ b/relay/channel/tencent/relay-tencent.go @@ -13,7 +13,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -107,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt var tencentResponse TencentChatResponse err := json.Unmarshal([]byte(data), &tencentResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) continue } @@ -118,12 +117,12 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt err = helper.ObjectData(c, response) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) } } if err := scanner.Err(); err != nil { - logger.SysError("error reading stream: " + err.Error()) + common.SysLog("error reading stream: " + err.Error()) } helper.Done(c) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 2cc4f663..b46cb952 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -188,17 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { switch info.RelayMode { case constant.RelayModeChatCompletions: if strings.HasPrefix(info.UpstreamModelName, "bot") { - return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil } - return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil case constant.RelayModeEmbeddings: - return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil case constant.RelayModeImagesGenerations: - return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil case constant.RelayModeImagesEdits: - return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil case constant.RelayModeRerank: - return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil + return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) diff --git a/relay/channel/xai/adaptor.go b/relay/channel/xai/adaptor.go index 6a3a5370..d5671ab2 100644 --- a/relay/channel/xai/adaptor.go +++ b/relay/channel/xai/adaptor.go @@ -39,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf xaiRequest := ImageRequest{ Model: request.Model, Prompt: request.Prompt, - N: request.N, + N: int(request.N), ResponseFormat: request.ResponseFormat, } return xaiRequest, nil @@ -49,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil + return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/xai/text.go b/relay/channel/xai/text.go index 4d4e7b92..5cae9c0a 100644 --- a/relay/channel/xai/text.go +++ b/relay/channel/xai/text.go @@ -6,7 +6,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/helper" @@ -48,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re var xAIResp *dto.ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &xAIResp) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } @@ -64,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) err = helper.ObjectData(c, openaiResponse) if err != nil { - logger.SysError(err.Error()) + common.SysLog(err.Error()) } return true }) diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 398bb08d..54ed476f 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -11,7 +11,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/relay/helper" "one-api/types" "strings" @@ -144,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a response := streamResponseXunfei2OpenAI(&xunfeiResponse) jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -219,20 +218,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap for { _, msg, err := conn.ReadMessage() if err != nil { - logger.SysError("error reading stream response: " + err.Error()) + common.SysLog("error reading stream response: " + err.Error()) break } var response XunfeiChatResponse err = json.Unmarshal(msg, &response) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) break } dataChan <- response if response.Payload.Choices.Status == 2 { err := conn.Close() if err != nil { - logger.SysError("error closing websocket connection: " + err.Error()) + common.SysLog("error closing websocket connection: " + err.Error()) } break } @@ -283,6 +282,6 @@ func getAPIVersion(c *gin.Context, modelName string) string { return apiVersion } apiVersion = "v1.1" - logger.SysLog("api_version not found, using default: " + apiVersion) + common.SysLog("api_version not found, using default: " + apiVersion) return apiVersion } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index e3be0e8e..bd27c90b 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -45,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.IsStream { method = "sse-invoke" } - return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 65b662b6..8eb0dcc1 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -8,7 +8,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" @@ -40,7 +39,7 @@ func getZhipuToken(apikey string) string { split := strings.Split(apikey, ".") if len(split) != 2 { - logger.SysError("invalid zhipu key: " + apikey) + common.SysLog("invalid zhipu key: " + apikey) return "" } @@ -188,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. response := streamResponseZhipu2OpenAI(data) jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) @@ -197,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http. var zhipuResponse ZhipuStreamMetaResponse err := json.Unmarshal([]byte(data), &zhipuResponse) if err != nil { - logger.SysError("error unmarshalling stream response: " + err.Error()) + common.SysLog("error unmarshalling stream response: " + err.Error()) return true } response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) jsonResponse, err := json.Marshal(response) if err != nil { - logger.SysError("error marshalling stream response: " + err.Error()) + common.SysLog("error marshalling stream response: " + err.Error()) return true } usage = zhipuUsage diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index a83e30e6..0fae3767 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -43,7 +43,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl) + baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl) switch info.RelayMode { case relayconstant.RelayModeEmbeddings: return fmt.Sprintf("%s/embeddings", baseUrl), nil diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 59be0011..742cd61c 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -66,6 +66,7 @@ type ChannelMeta struct { ChannelOtherSettings dto.ChannelOtherSettings UpstreamModelName string IsModelMapped bool + SupportStreamOptions bool // 是否支持流式选项 } type RelayInfo struct { @@ -86,9 +87,9 @@ type RelayInfo struct { RelayMode int OriginModelName string //RecodeModelName string - RequestURLPath string - PromptTokens int - SupportStreamOptions bool + RequestURLPath string + PromptTokens int + //SupportStreamOptions bool ShouldIncludeUsage bool DisablePing bool // 是否禁止向下游发送自定义 Ping ClientWs *websocket.Conn @@ -135,6 +136,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) { ParamOverride: paramOverride, UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel), IsModelMapped: false, + SupportStreamOptions: false, } channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting) @@ -146,6 +148,10 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) { if ok { channelMeta.ChannelOtherSettings = channelOtherSettings } + + if streamSupportedChannels[channelMeta.ChannelType] { + channelMeta.SupportStreamOptions = true + } info.ChannelMeta = channelMeta } @@ -268,6 +274,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { startTime = time.Now() } + isStream := false + + if request != nil { + isStream = request.IsStream(c) + } + // firstResponseTime = time.Now() - 1 second info := &RelayInfo{ @@ -289,7 +301,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { isFirstResponse: true, RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path), RequestURLPath: c.Request.URL.String(), - IsStream: request.IsStream(c), + IsStream: isStream, StartTime: startTime, FirstResponseTime: startTime.Add(-time.Second), @@ -339,6 +351,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req return GenRelayInfoResponses(c, request), nil } return nil, errors.New("request is not a OpenAIResponsesRequest") + case types.RelayFormatTask: + return genBaseRelayInfo(c, nil), nil + case types.RelayFormatMjProxy: + return genBaseRelayInfo(c, nil), nil default: return nil, errors.New("invalid relay format") } @@ -367,11 +383,15 @@ type TaskRelayInfo struct { ConsumeQuota bool } -func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - info := &TaskRelayInfo{ - RelayInfo: GenRelayInfo(c), +func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) { + relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + return nil, err } - return info + info := &TaskRelayInfo{ + RelayInfo: relayInfo, + } + return info, nil } type TaskSubmitReq struct { diff --git a/relay/helper/price.go b/relay/helper/price.go index 89fc3b66..fdc5b66d 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -53,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens var imageRatio float64 var cacheCreationRatio float64 if !usePrice { - preConsumedTokens := common.PreConsumedQuota + preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota) if meta.MaxTokens != 0 { - preConsumedTokens = promptTokens + meta.MaxTokens + preConsumedTokens += meta.MaxTokens } var success bool var matchName string @@ -102,27 +102,27 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens } // ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task) -//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { -// groupRatioInfo := HandleGroupRatio(c, info) -// -// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) -// // 如果没有配置价格,则使用默认价格 -// if !success { -// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] -// if !ok { -// modelPrice = 0.1 -// } else { -// modelPrice = defaultPrice -// } -// } -// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) -// priceData := types.PerCallPriceData{ -// ModelPrice: modelPrice, -// Quota: quota, -// GroupRatioInfo: groupRatioInfo, -// } -// return priceData -//} +func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData { + groupRatioInfo := HandleGroupRatio(c, info) + + modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) + // 如果没有配置价格,则使用默认价格 + if !success { + defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio) + priceData := types.PerCallPriceData{ + ModelPrice: modelPrice, + Quota: quota, + GroupRatioInfo: groupRatioInfo, + } + return priceData +} func ContainPriceOrRatio(modelName string) bool { _, ok := ratio_setting.GetModelPrice(modelName, false) diff --git a/relay/helper/valid_request.go b/relay/helper/valid_request.go index 0bc51774..1d556a33 100644 --- a/relay/helper/valid_request.go +++ b/relay/helper/valid_request.go @@ -36,7 +36,7 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt case types.RelayFormatOpenAIAudio: request, err = GetAndValidAudioRequest(c, relayMode) case types.RelayFormatOpenAIRealtime: - // nothing to do, no request body + request = &dto.BaseRequest{} default: return nil, fmt.Errorf("unsupported relay format: %s", format) } diff --git a/relay/chat_handler.go b/relay/mjproxy_handler.go similarity index 87% rename from relay/chat_handler.go rename to relay/mjproxy_handler.go index 30bce55c..756ad450 100644 --- a/relay/chat_handler.go +++ b/relay/mjproxy_handler.go @@ -10,7 +10,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -171,13 +170,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo return } -func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { - startTime := time.Now().UnixNano() / int64(time.Millisecond) - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - //group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse { var swapFaceRequest dto.SwapFaceRequest err := common.UnmarshalBodyReusable(c, &swapFaceRequest) if err != nil { @@ -188,9 +181,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } modelName := service.CoverActionToModelName(constant.MjActionSwapFace) - priceData := helper.ModelPriceHelperPerCall(c, relayInfo) + priceData := helper.ModelPriceHelperPerCall(c, info) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -213,32 +206,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { } defer func() { if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 { - err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) + err := service.PostConsumeQuota(info, priceData.Quota, 0, true) if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace) other := service.GenerateMjOtherInfo(priceData) - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, - TokenId: tokenId, - UserQuota: userQuota, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota) } }() midjResponse := &mjResp.Response midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: info.UserId, Code: midjResponse.Code, Action: constant.MjActionSwapFace, MjId: midjResponse.Result, @@ -246,7 +238,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse { PromptEn: "", Description: midjResponse.Description, State: "", - SubmitTime: startTime, + SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond), StartTime: time.Now().UnixNano() / int64(time.Millisecond), FinishTime: 0, ImageUrl: "", @@ -370,14 +362,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse return nil } -func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse { - - //tokenId := c.GetInt("token_id") - //channelType := c.GetInt("channel") - userId := c.GetInt("id") - group := c.GetString("group") - channelId := c.GetInt("channel_id") - relayInfo := relaycommon.GenRelayInfo(c) +func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse { consumeQuota := true var midjRequest dto.MidjourneyRequest err := common.UnmarshalBodyReusable(c, &midjRequest) @@ -385,35 +370,35 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed") } - if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息 mjErr := service.CoverPlusActionToNormalAction(&midjRequest) if mjErr != nil { return mjErr } - relayMode = relayconstant.RelayModeMidjourneyChange + relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange } - if relayMode == relayconstant.RelayModeMidjourneyVideo { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo } - if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复 if midjRequest.Prompt == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required") } midjRequest.Action = constant.MjActionImagine - } else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复 midjRequest.Action = constant.MjActionDescribe - } else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复 midjRequest.Action = constant.MjActionEdits - } else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only midjRequest.Action = constant.MjActionShorten - } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionBlend - } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionUpload } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" - if relayMode == relayconstant.RelayModeMidjourneyChange { + if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange { if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") } else if midjRequest.Action == "" { @@ -423,7 +408,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } //action = midjRequest.Action mjId = midjRequest.TaskId - } else if relayMode == relayconstant.RelayModeMidjourneySimpleChange { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange { if midjRequest.Content == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required") } @@ -433,13 +418,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons } mjId = params.TaskId midjRequest.Action = params.Action - } else if relayMode == relayconstant.RelayModeMidjourneyModal { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal { //if midjRequest.MaskBase64 == "" { // return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required") //} mjId = midjRequest.TaskId midjRequest.Action = constant.MjActionModal - } else if relayMode == relayconstant.RelayModeMidjourneyVideo { + } else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo { midjRequest.Action = constant.MjActionVideo if midjRequest.TaskId == "" { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required") @@ -449,12 +434,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons mjId = midjRequest.TaskId } - originTask := model.GetByMJId(userId, mjId) + originTask := model.GetByMJId(relayInfo.UserId, mjId) if originTask == nil { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") } else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理 if setting.MjActionCheckSuccessEnabled { - if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { + if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal { return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") } } @@ -497,7 +482,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons priceData := helper.ModelPriceHelperPerCall(c, relayInfo) - userQuota, err := model.GetUserQuota(userId, false) + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { return &dto.MidjourneyResponse{ Code: 4, @@ -522,24 +507,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if consumeQuota && midjResponseWithStatus.StatusCode == 200 { err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true) if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result) other := service.GenerateMjOtherInfo(priceData) model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: channelId, + ChannelId: relayInfo.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: priceData.Quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, - Group: group, + Group: relayInfo.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota) - model.UpdateChannelUsedQuota(channelId, priceData.Quota) + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota) } }() @@ -551,7 +535,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons // 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}} // other: 提交错误,description为错误描述 midjourneyTask := &model.Midjourney{ - UserId: userId, + UserId: relayInfo.UserId, Code: midjResponse.Code, Action: midjRequest.Action, MjId: midjResponse.Result, @@ -573,7 +557,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons //无实例账号自动禁用渠道(No available account instance) channel, err := model.GetChannelById(midjourneyTask.ChannelId, true) if err != nil { - logger.SysError("get_channel_null: " + err.Error()) + common.SysLog("get_channel_null: " + err.Error()) } if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance") diff --git a/relay/relay-text.go b/relay/relay-text.go index de750e76..5c07c718 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -44,6 +44,26 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } + includeUsage := true + // 判断用户是否需要返回使用情况 + if textRequest.StreamOptions != nil { + includeUsage = textRequest.StreamOptions.IncludeUsage + } + + // 如果不支持StreamOptions,将StreamOptions设置为nil + if !info.SupportStreamOptions || !textRequest.Stream { + textRequest.StreamOptions = nil + } else { + // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions + if constant.ForceStreamOption { + textRequest.StreamOptions = &dto.StreamOptions{ + IncludeUsage: true, + } + } + } + + info.ShouldIncludeUsage = includeUsage + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) diff --git a/relay/relay_task.go b/relay/relay_task.go index ae002d73..95b8083b 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -10,7 +10,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" @@ -28,7 +27,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if platform == "" { platform = GetTaskPlatform(c) } - relayInfo := relaycommon.GenTaskRelayInfo(c) + + relayInfo, err := relaycommon.GenTaskRelayInfo(c) + if err != nil { + return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) + } adaptor := GetTaskAdaptor(platform) if adaptor == nil { @@ -98,7 +101,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - relayInfo.BaseUrl = channel.GetBaseURL() + relayInfo.ChannelBaseUrl = channel.GetBaseURL() relayInfo.ChannelId = originTask.ChannelId } } @@ -128,7 +131,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) + common.SysLog("error consuming token remain quota: " + err.Error()) } if quota != 0 { tokenName := c.GetString("token_name") @@ -150,7 +153,6 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { Quota: quota, Content: logContent, TokenId: relayInfo.TokenId, - UserQuota: userQuota, Group: relayInfo.UsingGroup, Other: other, }) diff --git a/relay/websocket.go b/relay/websocket.go index 22b681f1..2d313154 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -4,7 +4,6 @@ import ( "fmt" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" "one-api/service" "one-api/types" @@ -12,58 +11,35 @@ import ( "github.com/gorilla/websocket" ) -func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) { - relayInfo := relaycommon.GenRelayInfoWs(c, ws) +func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { + info.InitChannelMeta(c) - err := helper.ModelMappedHelper(c, relayInfo, nil) - if err != nil { - return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) - } - - priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0) - if err != nil { - return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry()) - } - - // pre-consume quota 预消耗配额 - preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) - if newAPIError != nil { - return newAPIError - } - - defer func() { - if newAPIError != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - } - }() - - adaptor := GetAdaptor(relayInfo.ApiType) + adaptor := GetAdaptor(info.ApiType) if adaptor == nil { - return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) + return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } - adaptor.Init(relayInfo) + adaptor.Init(info) //var requestBody io.Reader //firstWssRequest, _ := c.Get("first_wss_request") //requestBody = bytes.NewBuffer(firstWssRequest.([]byte)) statusCodeMappingStr := c.GetString("status_code_mapping") - resp, err := adaptor.DoRequest(c, relayInfo, nil) + resp, err := adaptor.DoRequest(c, info, nil) if err != nil { return types.NewError(err, types.ErrorCodeDoRequestFailed) } if resp != nil { - relayInfo.TargetWs = resp.(*websocket.Conn) - defer relayInfo.TargetWs.Close() + info.TargetWs = resp.(*websocket.Conn) + defer info.TargetWs.Close() } - usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo) + usage, newAPIError := adaptor.DoResponse(c, nil, info) if newAPIError != nil { // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota, - userQuota, priceData, "") + service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "") return nil } diff --git a/router/main.go b/router/main.go index 7653f3a5..23576427 100644 --- a/router/main.go +++ b/router/main.go @@ -3,12 +3,12 @@ package router import ( "embed" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" - "one-api/logger" "os" "strings" + + "github.com/gin-gonic/gin" ) func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { @@ -19,7 +19,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") if common.IsMasterNode && frontendBaseUrl != "" { frontendBaseUrl = "" - logger.SysLog("FRONTEND_BASE_URL is ignored on master node") + common.SysLog("FRONTEND_BASE_URL is ignored on master node") } if frontendBaseUrl == "" { SetWebRouter(router, buildFS, indexPage) diff --git a/service/cf_worker.go b/service/cf_worker.go index 65f7f133..ae6e1ffe 100644 --- a/service/cf_worker.go +++ b/service/cf_worker.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "net/http" - "one-api/logger" + "one-api/common" "one-api/setting" "strings" ) @@ -44,14 +44,14 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { func DoDownloadRequest(originUrl string) (resp *http.Response, err error) { if setting.EnableWorker() { - logger.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl)) req := &WorkerRequest{ URL: originUrl, Key: setting.WorkerValidKey, } return DoWorkerRequest(req) } else { - logger.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) + common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl)) return http.Get(originUrl) } } diff --git a/service/error.go b/service/error.go index 668731b0..ef5cbbde 100644 --- a/service/error.go +++ b/service/error.go @@ -7,7 +7,6 @@ import ( "net/http" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/types" "strconv" "strings" @@ -59,7 +58,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError lowerText := strings.ToLower(text) if !strings.HasPrefix(lowerText, "get file base64 from url") { if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - logger.SysLog(fmt.Sprintf("error: %s", text)) + common.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } } @@ -139,7 +138,7 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { text := err.Error() lowerText := strings.ToLower(text) if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { - logger.SysLog(fmt.Sprintf("error: %s", text)) + common.SysLog(fmt.Sprintf("error: %s", text)) text = "请求上游地址失败" } //避免暴露内部错误 diff --git a/service/image.go b/service/image.go index 957ca041..252093f1 100644 --- a/service/image.go +++ b/service/image.go @@ -8,8 +8,8 @@ import ( "image" "io" "net/http" + "one-api/common" "one-api/constant" - "one-api/logger" "strings" "golang.org/x/image/webp" @@ -113,7 +113,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) { func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { response, err := DoDownloadRequest(imageUrl) if err != nil { - logger.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) + common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) return image.Config{}, "", err } defer response.Body.Close() @@ -131,7 +131,7 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { var readData []byte for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { - logger.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) + common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) // 从response.Body读取更多的数据直到达到当前的限制 additionalData := make([]byte, limit-int64(len(readData))) @@ -157,11 +157,11 @@ func getImageConfig(reader io.Reader) (image.Config, string, error) { config, format, err := image.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) - logger.SysLog(err.Error()) + common.SysLog(err.Error()) config, err = webp.DecodeConfig(reader) if err != nil { err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) - logger.SysLog(err.Error()) + common.SysLog(err.Error()) } format = "webp" } diff --git a/service/log_info_generate.go b/service/log_info_generate.go index 0dae9a03..7a609c9f 100644 --- a/service/log_info_generate.go +++ b/service/log_info_generate.go @@ -5,7 +5,7 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/relay/helper" + "one-api/types" "github.com/gin-gonic/gin" ) @@ -78,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, return info } -func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} { +func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} { other := make(map[string]interface{}) other["model_price"] = priceData.ModelPrice other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio diff --git a/service/midjourney.go b/service/midjourney.go index 1d232739..916d02d0 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relayconstant "one-api/relay/constant" "one-api/setting" "strconv" @@ -213,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU defer cancel() resp, err := GetHttpClient().Do(req) if err != nil { - logger.SysError("do request failed: " + err.Error()) + common.SysLog("do request failed: " + err.Error()) return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err } statusCode := resp.StatusCode diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 3c4d0e7e..3902ef92 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -6,6 +6,7 @@ import ( "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "net/http" + "one-api/common" "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" @@ -19,7 +20,7 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) if err != nil { - logger.SysError("error return pre-consumed quota: " + err.Error()) + common.SysLog("error return pre-consumed quota: " + err.Error()) } }) } diff --git a/service/token_counter.go b/service/token_counter.go index ec817182..43a508c1 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -10,7 +10,6 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" - "one-api/logger" relaycommon "one-api/relay/common" "one-api/types" "strings" @@ -32,9 +31,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec) var tokenEncoderMutex sync.RWMutex func InitTokenEncoders() { - logger.SysLog("initializing token encoders") + common.SysLog("initializing token encoders") defaultTokenEncoder = codec.NewCl100kBase() - logger.SysLog("token encoders initialized") + common.SysLog("token encoders initialized") } func getTokenEncoder(model string) tokenizer.Codec { @@ -158,7 +157,7 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er if strings.HasPrefix(fileMeta.Data, "http") { config, format, err = DecodeUrlImageData(fileMeta.Data) } else { - logger.SysLog(fmt.Sprintf("decoding image")) + common.SysLog(fmt.Sprintf("decoding image")) config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data) } if err != nil { @@ -248,6 +247,11 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco if meta == nil { return 0, errors.New("token count meta is nil") } + + if info.RelayFormat == types.RelayFormatOpenAIRealtime { + return 0, nil + } + model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) tkm := CountTextToken(meta.CombineText, model) diff --git a/service/user_notify.go b/service/user_notify.go index 1fcc62d3..7c864a1b 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -4,7 +4,6 @@ import ( "fmt" "one-api/common" "one-api/dto" - "one-api/logger" "one-api/model" "strings" ) @@ -13,7 +12,7 @@ func NotifyRootUser(t string, subject string, content string) { user := model.GetRootUser().ToBaseUser() err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil)) if err != nil { - logger.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error())) } } @@ -26,7 +25,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // Check notification limit canSend, err := CheckNotificationLimit(userId, data.Type) if err != nil { - logger.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error())) + common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error())) return err } if !canSend { @@ -38,14 +37,14 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // check setting email userEmail = userSetting.NotificationEmail if userEmail == "" { - logger.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) + common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", userId)) return nil } return sendEmailNotify(userEmail, data) case dto.NotifyTypeWebhook: webhookURLStr := userSetting.WebhookUrl if webhookURLStr == "" { - logger.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) + common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId)) return nil } diff --git a/setting/chat.go b/setting/chat.go index b417af28..bd1e26e3 100644 --- a/setting/chat.go +++ b/setting/chat.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/logger" + "one-api/common" ) var Chats = []map[string]string{ @@ -37,7 +37,7 @@ func UpdateChatsByJsonString(jsonString string) error { func Chats2JsonString() string { jsonBytes, err := json.Marshal(Chats) if err != nil { - logger.SysError("error marshalling chats: " + err.Error()) + common.SysLog("error marshalling chats: " + err.Error()) return "[]" } return string(jsonBytes) diff --git a/setting/config/config.go b/setting/config/config.go index 2e43e0a7..3af51b14 100644 --- a/setting/config/config.go +++ b/setting/config/config.go @@ -2,7 +2,7 @@ package config import ( "encoding/json" - "one-api/logger" + "one-api/common" "reflect" "strconv" "strings" @@ -57,7 +57,7 @@ func (cm *ConfigManager) LoadFromDB(options map[string]string) error { // 如果找到配置项,则更新配置 if len(configMap) > 0 { if err := updateConfigFromMap(config, configMap); err != nil { - logger.SysError("failed to update config " + name + ": " + err.Error()) + common.SysError("failed to update config " + name + ": " + err.Error()) continue } } diff --git a/setting/rate_limit.go b/setting/rate_limit.go index dcb9fae5..141463e1 100644 --- a/setting/rate_limit.go +++ b/setting/rate_limit.go @@ -4,7 +4,7 @@ import ( "encoding/json" "fmt" "math" - "one-api/logger" + "one-api/common" "sync" ) @@ -21,7 +21,7 @@ func ModelRequestRateLimitGroup2JSONString() string { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup) if err != nil { - logger.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/cache_ratio.go b/setting/ratio_setting/cache_ratio.go index 47079850..5993cdee 100644 --- a/setting/ratio_setting/cache_ratio.go +++ b/setting/ratio_setting/cache_ratio.go @@ -2,7 +2,7 @@ package ratio_setting import ( "encoding/json" - "one-api/logger" + "one-api/common" "sync" ) @@ -89,7 +89,7 @@ func CacheRatio2JSONString() string { defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { - logger.SysError("error marshalling cache ratio: " + err.Error()) + common.SysLog("error marshalling cache ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/ratio_setting/group_ratio.go b/setting/ratio_setting/group_ratio.go index c1a666e9..c42553da 100644 --- a/setting/ratio_setting/group_ratio.go +++ b/setting/ratio_setting/group_ratio.go @@ -3,7 +3,7 @@ package ratio_setting import ( "encoding/json" "errors" - "one-api/logger" + "one-api/common" "sync" ) @@ -48,7 +48,7 @@ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(groupRatio) if err != nil { - logger.SysError("error marshalling model ratio: " + err.Error()) + common.SysLog("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -67,7 +67,7 @@ func GetGroupRatio(name string) float64 { ratio, ok := groupRatio[name] if !ok { - logger.SysError("group ratio not found: " + name) + common.SysLog("group ratio not found: " + name) return 1 } return ratio @@ -94,7 +94,7 @@ func GroupGroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupGroupRatio) if err != nil { - logger.SysError("error marshalling group-group ratio: " + err.Error()) + common.SysLog("error marshalling group-group ratio: " + err.Error()) } return string(jsonBytes) } diff --git a/setting/user_usable_group.go b/setting/user_usable_group.go index bcbe712c..57e4beec 100644 --- a/setting/user_usable_group.go +++ b/setting/user_usable_group.go @@ -2,7 +2,7 @@ package setting import ( "encoding/json" - "one-api/logger" + "one-api/common" "sync" ) @@ -29,7 +29,7 @@ func UserUsableGroups2JSONString() string { jsonBytes, err := json.Marshal(userUsableGroups) if err != nil { - logger.SysError("error marshalling user groups: " + err.Error()) + common.SysLog("error marshalling user groups: " + err.Error()) } return string(jsonBytes) } diff --git a/types/relay_format.go b/types/relay_format.go index 4c29d649..6d94a70b 100644 --- a/types/relay_format.go +++ b/types/relay_format.go @@ -12,4 +12,7 @@ const ( RelayFormatOpenAIRealtime = "openai_realtime" RelayFormatRerank = "rerank" RelayFormatEmbedding = "embedding" + + RelayFormatTask = "task" + RelayFormatMjProxy = "mj_proxy" ) diff --git a/types/relay_request.go b/types/relay_request.go deleted file mode 100644 index b9d092f0..00000000 --- a/types/relay_request.go +++ /dev/null @@ -1,27 +0,0 @@ -package types - -type RelayRequest struct { - OriginRequest any - Format RelayFormat - PromptTokenCount int -} - -func (r *RelayRequest) CopyOriginRequest() any { - if r.OriginRequest == nil { - return nil - } - switch v := r.OriginRequest.(type) { - case *GeneralOpenAIRequest: - return v.Copy() - case *GeneralClaudeRequest: - return v.Copy() - case *GeneralGeminiRequest: - return v.Copy() - case *GeneralRerankRequest: - return v.Copy() - case *GeneralEmbeddingRequest: - return v.Copy() - default: - return nil - } -}