Merge remote-tracking branch 'origin/alpha' into fix/openrouter-custom-ratio-billing

This commit is contained in:
yyhhyyyyyy
2025-08-15 14:58:42 +08:00
140 changed files with 3384 additions and 2647 deletions

5
common/quota.go Normal file
View File

@@ -0,0 +1,5 @@
package common
func GetTrustQuota() int {
return int(10 * QuotaPerUnit)
}

View File

@@ -99,12 +99,75 @@ func GetJsonString(data any) string {
return string(b) return string(b)
} }
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string // MaskEmail masks a user email to prevent PII leakage in logs
// Returns "***masked***" if email is empty, otherwise shows only the domain part
func MaskEmail(email string) string {
if email == "" {
return "***masked***"
}
// Find the @ symbol
atIndex := strings.Index(email, "@")
if atIndex == -1 {
// No @ symbol found, return masked
return "***masked***"
}
// Return only the domain part with @ symbol
return "***@" + email[atIndex+1:]
}
// maskHostTail returns the tail parts of a domain/host that should be preserved.
// It keeps 2 parts for likely country-code TLDs (e.g., co.uk, com.cn), otherwise keeps only the TLD.
func maskHostTail(parts []string) []string {
if len(parts) < 2 {
return parts
}
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
return []string{secondLastPart, lastPart}
}
return []string{lastPart}
}
// maskHostForURL collapses subdomains and keeps only masked prefix + preserved tail.
// Example: api.openai.com -> ***.com, sub.domain.co.uk -> ***.co.uk
func maskHostForURL(host string) string {
parts := strings.Split(host, ".")
if len(parts) < 2 {
return "***"
}
tail := maskHostTail(parts)
return "***." + strings.Join(tail, ".")
}
// maskHostForPlainDomain masks a plain domain and reflects subdomain depth with multiple ***.
// Example: openai.com -> ***.com, api.openai.com -> ***.***.com, sub.domain.co.uk -> ***.***.co.uk
func maskHostForPlainDomain(domain string) string {
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
tail := maskHostTail(parts)
numStars := len(parts) - len(tail)
if numStars < 1 {
numStars = 1
}
stars := strings.TrimSuffix(strings.Repeat("***.", numStars), ".")
return stars + "." + strings.Join(tail, ".")
}
// MaskSensitiveInfo masks sensitive information like URLs, IPs, and domain names in a string
// Example: // Example:
// http://example.com -> http://***.com // http://example.com -> http://***.com
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=*** // https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/*** // https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
// 192.168.1.1 -> ***.***.***.*** // 192.168.1.1 -> ***.***.***.***
// openai.com -> ***.com
// www.openai.com -> ***.***.com
// api.openai.com -> ***.***.com
func MaskSensitiveInfo(str string) string { func MaskSensitiveInfo(str string) string {
// Mask URLs // Mask URLs
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`) urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
@@ -119,32 +182,8 @@ func MaskSensitiveInfo(str string) string {
return urlStr return urlStr
} }
// Split host by dots // Mask host with unified logic
parts := strings.Split(host, ".") maskedHost := maskHostForURL(host)
if len(parts) < 2 {
// If less than 2 parts, just mask the whole host
return u.Scheme + "://***" + u.Path
}
// Keep the TLD (Top Level Domain) and mask the rest
var maskedHost string
if len(parts) == 2 {
// example.com -> ***.com
maskedHost = "***." + parts[len(parts)-1]
} else {
// Handle cases like sub.domain.co.uk or api.example.com
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
maskedHost = "***." + secondLastPart + "." + lastPart
} else {
// Regular TLD like .com, .org
maskedHost = "***." + lastPart
}
}
result := u.Scheme + "://" + maskedHost result := u.Scheme + "://" + maskedHost
@@ -184,6 +223,12 @@ func MaskSensitiveInfo(str string) string {
return result return result
}) })
// Mask domain names without protocol (like openai.com, www.openai.com)
domainPattern := regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
str = domainPattern.ReplaceAllStringFunc(str, func(domain string) string {
return maskHostForPlainDomain(domain)
})
// Mask IP addresses // Mask IP addresses
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`) ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
str = ipPattern.ReplaceAllString(str, "***.***.***.***") str = ipPattern.ReplaceAllString(str, "***.***.***.***")

24
common/sys_log.go Normal file
View File

@@ -0,0 +1,24 @@
package common
import (
"fmt"
"github.com/gin-gonic/gin"
"os"
"time"
)
func SysLog(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysError(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}

View File

@@ -3,6 +3,8 @@ package constant
type ContextKey string type ContextKey string
const ( const (
ContextKeyPromptTokens ContextKey = "prompt_tokens"
ContextKeyOriginalModel ContextKey = "original_model" ContextKeyOriginalModel ContextKey = "original_model"
ContextKeyRequestStartTime ContextKey = "request_start_time" ContextKeyRequestStartTime ContextKey = "request_start_time"

View File

@@ -132,10 +132,27 @@ func testChannel(channel *model.Channel, testModel string) testResult {
newAPIError: newAPIError, newAPIError: newAPIError,
} }
} }
request := buildTestRequest(testModel)
info := relaycommon.GenRelayInfo(c) // Determine relay format based on request path
relayFormat := types.RelayFormatOpenAI
if c.Request.URL.Path == "/v1/embeddings" {
relayFormat = types.RelayFormatEmbedding
}
err = helper.ModelMappedHelper(c, info, nil) info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
if err != nil {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
}
}
info.InitChannelMeta(c)
err = helper.ModelMappedHelper(c, info, request)
if err != nil { if err != nil {
return testResult{ return testResult{
context: c, context: c,
@@ -143,7 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError), newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
} }
} }
testModel = info.UpstreamModelName testModel = info.UpstreamModelName
request.Model = testModel
apiType, _ := common.ChannelType2APIType(channel.Type) apiType, _ := common.ChannelType2APIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
@@ -155,13 +174,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
} }
} }
request := buildTestRequest(testModel) //// 创建一个用于日志的 info 副本,移除 ApiKey
// 创建一个用于日志的 info 副本,移除 ApiKey //logInfo := info
logInfo := *info //logInfo.ApiKey = ""
logInfo.ApiKey = "" common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
if err != nil { if err != nil {
return testResult{ return testResult{
context: c, context: c,

View File

@@ -3,101 +3,102 @@
package controller package controller
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin"
) )
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.* // MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
func MigrateConsoleSetting(c *gin.Context) { func MigrateConsoleSetting(c *gin.Context) {
// 读取全部 option // 读取全部 option
opts, err := model.AllOption() opts, err := model.AllOption()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
return return
} }
// 建立 map // 建立 map
valMap := map[string]string{} valMap := map[string]string{}
for _, o := range opts { for _, o := range opts {
valMap[o.Key] = o.Value valMap[o.Key] = o.Value
} }
// 处理 APIInfo // 处理 APIInfo
if v := valMap["ApiInfo"]; v != "" { if v := valMap["ApiInfo"]; v != "" {
var arr []map[string]interface{} var arr []map[string]interface{}
if err := json.Unmarshal([]byte(v), &arr); err == nil { if err := json.Unmarshal([]byte(v), &arr); err == nil {
if len(arr) > 50 { if len(arr) > 50 {
arr = arr[:50] arr = arr[:50]
} }
bytes, _ := json.Marshal(arr) bytes, _ := json.Marshal(arr)
model.UpdateOption("console_setting.api_info", string(bytes)) model.UpdateOption("console_setting.api_info", string(bytes))
} }
model.UpdateOption("ApiInfo", "") model.UpdateOption("ApiInfo", "")
} }
// Announcements 直接搬 // Announcements 直接搬
if v := valMap["Announcements"]; v != "" { if v := valMap["Announcements"]; v != "" {
model.UpdateOption("console_setting.announcements", v) model.UpdateOption("console_setting.announcements", v)
model.UpdateOption("Announcements", "") model.UpdateOption("Announcements", "")
} }
// FAQ 转换 // FAQ 转换
if v := valMap["FAQ"]; v != "" { if v := valMap["FAQ"]; v != "" {
var arr []map[string]interface{} var arr []map[string]interface{}
if err := json.Unmarshal([]byte(v), &arr); err == nil { if err := json.Unmarshal([]byte(v), &arr); err == nil {
out := []map[string]interface{}{} out := []map[string]interface{}{}
for _, item := range arr { for _, item := range arr {
q, _ := item["question"].(string) q, _ := item["question"].(string)
if q == "" { if q == "" {
q, _ = item["title"].(string) q, _ = item["title"].(string)
} }
a, _ := item["answer"].(string) a, _ := item["answer"].(string)
if a == "" { if a == "" {
a, _ = item["content"].(string) a, _ = item["content"].(string)
} }
if q != "" && a != "" { if q != "" && a != "" {
out = append(out, map[string]interface{}{"question": q, "answer": a}) out = append(out, map[string]interface{}{"question": q, "answer": a})
} }
} }
if len(out) > 50 { if len(out) > 50 {
out = out[:50] out = out[:50]
} }
bytes, _ := json.Marshal(out) bytes, _ := json.Marshal(out)
model.UpdateOption("console_setting.faq", string(bytes)) model.UpdateOption("console_setting.faq", string(bytes))
} }
model.UpdateOption("FAQ", "") model.UpdateOption("FAQ", "")
} }
// Uptime Kuma 迁移到新的 groups 结构console_setting.uptime_kuma_groups // Uptime Kuma 迁移到新的 groups 结构console_setting.uptime_kuma_groups
url := valMap["UptimeKumaUrl"] url := valMap["UptimeKumaUrl"]
slug := valMap["UptimeKumaSlug"] slug := valMap["UptimeKumaSlug"]
if url != "" && slug != "" { if url != "" && slug != "" {
// 仅当同时存在 URL 与 Slug 时才进行迁移 // 仅当同时存在 URL 与 Slug 时才进行迁移
groups := []map[string]interface{}{ groups := []map[string]interface{}{
{ {
"id": 1, "id": 1,
"categoryName": "old", "categoryName": "old",
"url": url, "url": url,
"slug": slug, "slug": slug,
"description": "", "description": "",
}, },
} }
bytes, _ := json.Marshal(groups) bytes, _ := json.Marshal(groups)
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes)) model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
} }
// 清空旧键内容 // 清空旧键内容
if url != "" { if url != "" {
model.UpdateOption("UptimeKumaUrl", "") model.UpdateOption("UptimeKumaUrl", "")
} }
if slug != "" { if slug != "" {
model.UpdateOption("UptimeKumaSlug", "") model.UpdateOption("UptimeKumaSlug", "")
} }
// 删除旧键记录 // 删除旧键记录
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"} oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{}) model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
// 重新加载 OptionMap // 重新加载 OptionMap
model.InitOptionMap() model.InitOptionMap()
common.SysLog("console setting migrated") common.SysLog("console setting migrated")
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"}) c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
} }

View File

@@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
@@ -28,7 +29,7 @@ func UpdateMidjourneyTaskBulk() {
continue continue
} }
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks))) logger.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string) taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney) taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0) nullTaskIds := make([]int, 0)
@@ -47,9 +48,9 @@ func UpdateMidjourneyTaskBulk() {
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else { } else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds)) logger.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
} }
} }
if len(taskChannelM) == 0 { if len(taskChannelM) == 0 {
@@ -57,20 +58,20 @@ func UpdateMidjourneyTaskBulk() {
} }
for channelId, taskIds := range taskChannelM { for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 { if len(taskIds) == 0 {
continue continue
} }
midjourneyChannel, err := model.CacheGetChannel(channelId) midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err)) logger.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{ err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId), "fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE", "status": "FAILURE",
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err)) logger.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
} }
continue continue
} }
@@ -81,7 +82,7 @@ func UpdateMidjourneyTaskBulk() {
}) })
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue continue
} }
// 设置超时时间 // 设置超时时间
@@ -93,22 +94,22 @@ func UpdateMidjourneyTaskBulk() {
req.Header.Set("mj-api-secret", midjourneyChannel.Key) req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := service.GetHttpClient().Do(req) resp, err := service.GetHttpClient().Do(req)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue continue
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
continue continue
} }
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue continue
} }
var responseItems []dto.MidjourneyDto var responseItems []dto.MidjourneyDto
err = json.Unmarshal(responseBody, &responseItems) err = json.Unmarshal(responseBody, &responseItems)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue continue
} }
resp.Body.Close() resp.Body.Close()
@@ -152,7 +153,7 @@ func UpdateMidjourneyTaskBulk() {
if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 { if responseItem.VideoUrls != nil && len(responseItem.VideoUrls) > 0 {
videoUrlsStr, err := json.Marshal(responseItem.VideoUrls) videoUrlsStr, err := json.Marshal(responseItem.VideoUrls)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err)) logger.LogError(ctx, fmt.Sprintf("序列化 VideoUrls 失败: %v", err))
task.VideoUrls = "[]" // 失败时设置为空数组 task.VideoUrls = "[]" // 失败时设置为空数组
} else { } else {
task.VideoUrls = string(videoUrlsStr) task.VideoUrls = string(videoUrlsStr)
@@ -163,7 +164,7 @@ func UpdateMidjourneyTaskBulk() {
shouldReturnQuota := false shouldReturnQuota := false
if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) logger.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%" task.Progress = "100%"
if task.Quota != 0 { if task.Quota != 0 {
shouldReturnQuota = true shouldReturnQuota = true
@@ -171,14 +172,14 @@ func UpdateMidjourneyTaskBulk() {
} }
err = task.Update() err = task.Update()
if err != nil { if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) logger.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
} else { } else {
if shouldReturnQuota { if shouldReturnQuota {
err = model.IncreaseUserQuota(task.UserId, task.Quota, false) err = model.IncreaseUserQuota(task.UserId, task.Quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) logger.LogError(ctx, "fail to increase user quota: "+err.Error())
} }
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(task.Quota)) logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, logger.LogQuota(task.Quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent) model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
} }
} }

View File

@@ -93,7 +93,9 @@ func init() {
if !success || apiType == constant.APITypeAIProxyLibrary { if !success || apiType == constant.APITypeAIProxyLibrary {
continue continue
} }
meta := &relaycommon.RelayInfo{ChannelType: i} meta := &relaycommon.RelayInfo{ChannelMeta: &relaycommon.ChannelMeta{
ChannelType: i,
}}
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
adaptor.Init(meta) adaptor.Init(meta)
channelId2Models[i] = adaptor.GetModelList() channelId2Models[i] = adaptor.GetModelList()

View File

@@ -69,7 +69,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
} }
if oidcResponse.AccessToken == "" { if oidcResponse.AccessToken == "" {
common.SysError("OIDC 获取 Token 失败,请检查设置!") common.SysLog("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!") return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
} }
@@ -85,7 +85,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
} }
defer res2.Body.Close() defer res2.Body.Close()
if res2.StatusCode != http.StatusOK { if res2.StatusCode != http.StatusOK {
common.SysError("OIDC 获取用户信息失败!请检查设置!") common.SysLog("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!") return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
} }
@@ -95,7 +95,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
return nil, err return nil, err
} }
if oidcUser.OpenID == "" || oidcUser.Email == "" { if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysError("OIDC 获取用户信息为空!请检查设置!") common.SysLog("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!") return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
} }
return &oidcUser, nil return &oidcUser, nil

View File

@@ -56,5 +56,5 @@ func Playground(c *gin.Context) {
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model) //middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now()) common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
Relay(c) Relay(c, types.RelayFormatOpenAI)
} }

View File

@@ -1,474 +1,474 @@
package controller package controller
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"strings" "one-api/logger"
"sync" "strings"
"time" "sync"
"time"
"one-api/common" "one-api/dto"
"one-api/dto" "one-api/model"
"one-api/model" "one-api/setting/ratio_setting"
"one-api/setting/ratio_setting"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
const ( const (
defaultTimeoutSeconds = 10 defaultTimeoutSeconds = 10
defaultEndpoint = "/api/ratio_config" defaultEndpoint = "/api/ratio_config"
maxConcurrentFetches = 8 maxConcurrentFetches = 8
) )
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
type upstreamResult struct { type upstreamResult struct {
Name string `json:"name"` Name string `json:"name"`
Data map[string]any `json:"data,omitempty"` Data map[string]any `json:"data,omitempty"`
Err string `json:"err,omitempty"` Err string `json:"err,omitempty"`
} }
func FetchUpstreamRatios(c *gin.Context) { func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return return
} }
if req.Timeout <= 0 { if req.Timeout <= 0 {
req.Timeout = defaultTimeoutSeconds req.Timeout = defaultTimeoutSeconds
} }
var upstreams []dto.UpstreamDTO var upstreams []dto.UpstreamDTO
if len(req.Upstreams) > 0 { if len(req.Upstreams) > 0 {
for _, u := range req.Upstreams { for _, u := range req.Upstreams {
if strings.HasPrefix(u.BaseURL, "http") { if strings.HasPrefix(u.BaseURL, "http") {
if u.Endpoint == "" { if u.Endpoint == "" {
u.Endpoint = defaultEndpoint u.Endpoint = defaultEndpoint
} }
u.BaseURL = strings.TrimRight(u.BaseURL, "/") u.BaseURL = strings.TrimRight(u.BaseURL, "/")
upstreams = append(upstreams, u) upstreams = append(upstreams, u)
} }
} }
} else if len(req.ChannelIDs) > 0 { } else if len(req.ChannelIDs) > 0 {
intIds := make([]int, 0, len(req.ChannelIDs)) intIds := make([]int, 0, len(req.ChannelIDs))
for _, id64 := range req.ChannelIDs { for _, id64 := range req.ChannelIDs {
intIds = append(intIds, int(id64)) intIds = append(intIds, int(id64))
} }
dbChannels, err := model.GetChannelsByIds(intIds) dbChannels, err := model.GetChannelsByIds(intIds)
if err != nil { if err != nil {
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
return return
} }
for _, ch := range dbChannels { for _, ch := range dbChannels {
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
upstreams = append(upstreams, dto.UpstreamDTO{ upstreams = append(upstreams, dto.UpstreamDTO{
ID: ch.Id, ID: ch.Id,
Name: ch.Name, Name: ch.Name,
BaseURL: strings.TrimRight(base, "/"), BaseURL: strings.TrimRight(base, "/"),
Endpoint: "", Endpoint: "",
}) })
} }
} }
} }
if len(upstreams) == 0 { if len(upstreams) == 0 {
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
return return
} }
var wg sync.WaitGroup var wg sync.WaitGroup
ch := make(chan upstreamResult, len(upstreams)) ch := make(chan upstreamResult, len(upstreams))
sem := make(chan struct{}, maxConcurrentFetches) sem := make(chan struct{}, maxConcurrentFetches)
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
for _, chn := range upstreams { for _, chn := range upstreams {
wg.Add(1) wg.Add(1)
go func(chItem dto.UpstreamDTO) { go func(chItem dto.UpstreamDTO) {
defer wg.Done() defer wg.Done()
sem <- struct{}{} sem <- struct{}{}
defer func() { <-sem }() defer func() { <-sem }()
endpoint := chItem.Endpoint endpoint := chItem.Endpoint
if endpoint == "" { if endpoint == "" {
endpoint = defaultEndpoint endpoint = defaultEndpoint
} else if !strings.HasPrefix(endpoint, "/") { } else if !strings.HasPrefix(endpoint, "/") {
endpoint = "/" + endpoint endpoint = "/" + endpoint
} }
fullURL := chItem.BaseURL + endpoint fullURL := chItem.BaseURL + endpoint
uniqueName := chItem.Name uniqueName := chItem.Name
if chItem.ID != 0 { if chItem.ID != 0 {
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
} }
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
defer cancel() defer cancel()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
if err != nil { if err != nil {
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()} ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return return
} }
resp, err := client.Do(httpReq) resp, err := client.Do(httpReq)
if err != nil { if err != nil {
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()} ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
ch <- upstreamResult{Name: uniqueName, Err: resp.Status} ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
return return
} }
// 兼容两种上游接口格式: // 兼容两种上游接口格式:
// type1: /api/ratio_config -> data 为 map[string]any包含 model_ratio/completion_ratio/cache_ratio/model_price // type1: /api/ratio_config -> data 为 map[string]any包含 model_ratio/completion_ratio/cache_ratio/model_price
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
var body struct { var body struct {
Success bool `json:"success"` Success bool `json:"success"`
Data json.RawMessage `json:"data"` Data json.RawMessage `json:"data"`
Message string `json:"message"` Message string `json:"message"`
} }
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: err.Error()} ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
return return
} }
if !body.Success { if !body.Success {
ch <- upstreamResult{Name: uniqueName, Err: body.Message} ch <- upstreamResult{Name: uniqueName, Err: body.Message}
return return
} }
// 尝试按 type1 解析 // 尝试按 type1 解析
var type1Data map[string]any var type1Data map[string]any
if err := json.Unmarshal(body.Data, &type1Data); err == nil { if err := json.Unmarshal(body.Data, &type1Data); err == nil {
// 如果包含至少一个 ratioTypes 字段,则认为是 type1 // 如果包含至少一个 ratioTypes 字段,则认为是 type1
isType1 := false isType1 := false
for _, rt := range ratioTypes { for _, rt := range ratioTypes {
if _, ok := type1Data[rt]; ok { if _, ok := type1Data[rt]; ok {
isType1 = true isType1 = true
break break
} }
} }
if isType1 { if isType1 {
ch <- upstreamResult{Name: uniqueName, Data: type1Data} ch <- upstreamResult{Name: uniqueName, Data: type1Data}
return return
} }
} }
// 如果不是 type1则尝试按 type2 (/api/pricing) 解析 // 如果不是 type1则尝试按 type2 (/api/pricing) 解析
var pricingItems []struct { var pricingItems []struct {
ModelName string `json:"model_name"` ModelName string `json:"model_name"`
QuotaType int `json:"quota_type"` QuotaType int `json:"quota_type"`
ModelRatio float64 `json:"model_ratio"` ModelRatio float64 `json:"model_ratio"`
ModelPrice float64 `json:"model_price"` ModelPrice float64 `json:"model_price"`
CompletionRatio float64 `json:"completion_ratio"` CompletionRatio float64 `json:"completion_ratio"`
} }
if err := json.Unmarshal(body.Data, &pricingItems); err != nil { if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
return return
} }
modelRatioMap := make(map[string]float64) modelRatioMap := make(map[string]float64)
completionRatioMap := make(map[string]float64) completionRatioMap := make(map[string]float64)
modelPriceMap := make(map[string]float64) modelPriceMap := make(map[string]float64)
for _, item := range pricingItems { for _, item := range pricingItems {
if item.QuotaType == 1 { if item.QuotaType == 1 {
modelPriceMap[item.ModelName] = item.ModelPrice modelPriceMap[item.ModelName] = item.ModelPrice
} else { } else {
modelRatioMap[item.ModelName] = item.ModelRatio modelRatioMap[item.ModelName] = item.ModelRatio
// completionRatio 可能为 0此时也直接赋值保持与上游一致 // completionRatio 可能为 0此时也直接赋值保持与上游一致
completionRatioMap[item.ModelName] = item.CompletionRatio completionRatioMap[item.ModelName] = item.CompletionRatio
} }
} }
converted := make(map[string]any) converted := make(map[string]any)
if len(modelRatioMap) > 0 { if len(modelRatioMap) > 0 {
ratioAny := make(map[string]any, len(modelRatioMap)) ratioAny := make(map[string]any, len(modelRatioMap))
for k, v := range modelRatioMap { for k, v := range modelRatioMap {
ratioAny[k] = v ratioAny[k] = v
} }
converted["model_ratio"] = ratioAny converted["model_ratio"] = ratioAny
} }
if len(completionRatioMap) > 0 { if len(completionRatioMap) > 0 {
compAny := make(map[string]any, len(completionRatioMap)) compAny := make(map[string]any, len(completionRatioMap))
for k, v := range completionRatioMap { for k, v := range completionRatioMap {
compAny[k] = v compAny[k] = v
} }
converted["completion_ratio"] = compAny converted["completion_ratio"] = compAny
} }
if len(modelPriceMap) > 0 { if len(modelPriceMap) > 0 {
priceAny := make(map[string]any, len(modelPriceMap)) priceAny := make(map[string]any, len(modelPriceMap))
for k, v := range modelPriceMap { for k, v := range modelPriceMap {
priceAny[k] = v priceAny[k] = v
} }
converted["model_price"] = priceAny converted["model_price"] = priceAny
} }
ch <- upstreamResult{Name: uniqueName, Data: converted} ch <- upstreamResult{Name: uniqueName, Data: converted}
}(chn) }(chn)
} }
wg.Wait() wg.Wait()
close(ch) close(ch)
localData := ratio_setting.GetExposedData() localData := ratio_setting.GetExposedData()
var testResults []dto.TestResult var testResults []dto.TestResult
var successfulChannels []struct { var successfulChannels []struct {
name string name string
data map[string]any data map[string]any
} }
for r := range ch { for r := range ch {
if r.Err != "" { if r.Err != "" {
testResults = append(testResults, dto.TestResult{ testResults = append(testResults, dto.TestResult{
Name: r.Name, Name: r.Name,
Status: "error", Status: "error",
Error: r.Err, Error: r.Err,
}) })
} else { } else {
testResults = append(testResults, dto.TestResult{ testResults = append(testResults, dto.TestResult{
Name: r.Name, Name: r.Name,
Status: "success", Status: "success",
}) })
successfulChannels = append(successfulChannels, struct { successfulChannels = append(successfulChannels, struct {
name string name string
data map[string]any data map[string]any
}{name: r.Name, data: r.Data}) }{name: r.Name, data: r.Data})
} }
} }
differences := buildDifferences(localData, successfulChannels) differences := buildDifferences(localData, successfulChannels)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"data": gin.H{ "data": gin.H{
"differences": differences, "differences": differences,
"test_results": testResults, "test_results": testResults,
}, },
}) })
} }
func buildDifferences(localData map[string]any, successfulChannels []struct { func buildDifferences(localData map[string]any, successfulChannels []struct {
name string name string
data map[string]any data map[string]any
}) map[string]map[string]dto.DifferenceItem { }) map[string]map[string]dto.DifferenceItem {
differences := make(map[string]map[string]dto.DifferenceItem) differences := make(map[string]map[string]dto.DifferenceItem)
allModels := make(map[string]struct{}) allModels := make(map[string]struct{})
for _, ratioType := range ratioTypes { for _, ratioType := range ratioTypes {
if localRatioAny, ok := localData[ratioType]; ok { if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok { if localRatio, ok := localRatioAny.(map[string]float64); ok {
for modelName := range localRatio { for modelName := range localRatio {
allModels[modelName] = struct{}{} allModels[modelName] = struct{}{}
} }
} }
} }
} }
for _, channel := range successfulChannels { for _, channel := range successfulChannels {
for _, ratioType := range ratioTypes { for _, ratioType := range ratioTypes {
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
for modelName := range upstreamRatio { for modelName := range upstreamRatio {
allModels[modelName] = struct{}{} allModels[modelName] = struct{}{}
} }
} }
} }
} }
confidenceMap := make(map[string]map[string]bool) confidenceMap := make(map[string]map[string]bool)
// 预处理阶段检查pricing接口的可信度 // 预处理阶段检查pricing接口的可信度
for _, channel := range successfulChannels { for _, channel := range successfulChannels {
confidenceMap[channel.name] = make(map[string]bool) confidenceMap[channel.name] = make(map[string]bool)
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
if hasModelRatio && hasCompletionRatio { if hasModelRatio && hasCompletionRatio {
// 遍历所有模型,检查是否满足不可信条件 // 遍历所有模型,检查是否满足不可信条件
for modelName := range allModels { for modelName := range allModels {
// 默认为可信 // 默认为可信
confidenceMap[channel.name][modelName] = true confidenceMap[channel.name][modelName] = true
// 检查是否满足不可信条件model_ratio为37.5且completion_ratio为1 // 检查是否满足不可信条件model_ratio为37.5且completion_ratio为1
if modelRatioVal, ok := modelRatios[modelName]; ok { if modelRatioVal, ok := modelRatios[modelName]; ok {
if completionRatioVal, ok := completionRatios[modelName]; ok { if completionRatioVal, ok := completionRatios[modelName]; ok {
// 转换为float64进行比较 // 转换为float64进行比较
if modelRatioFloat, ok := modelRatioVal.(float64); ok { if modelRatioFloat, ok := modelRatioVal.(float64); ok {
if completionRatioFloat, ok := completionRatioVal.(float64); ok { if completionRatioFloat, ok := completionRatioVal.(float64); ok {
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
confidenceMap[channel.name][modelName] = false confidenceMap[channel.name][modelName] = false
} }
} }
} }
} }
} }
} }
} else { } else {
// 如果不是从pricing接口获取的数据则全部标记为可信 // 如果不是从pricing接口获取的数据则全部标记为可信
for modelName := range allModels { for modelName := range allModels {
confidenceMap[channel.name][modelName] = true confidenceMap[channel.name][modelName] = true
} }
} }
} }
for modelName := range allModels { for modelName := range allModels {
for _, ratioType := range ratioTypes { for _, ratioType := range ratioTypes {
var localValue interface{} = nil var localValue interface{} = nil
if localRatioAny, ok := localData[ratioType]; ok { if localRatioAny, ok := localData[ratioType]; ok {
if localRatio, ok := localRatioAny.(map[string]float64); ok { if localRatio, ok := localRatioAny.(map[string]float64); ok {
if val, exists := localRatio[modelName]; exists { if val, exists := localRatio[modelName]; exists {
localValue = val localValue = val
} }
} }
} }
upstreamValues := make(map[string]interface{}) upstreamValues := make(map[string]interface{})
confidenceValues := make(map[string]bool) confidenceValues := make(map[string]bool)
hasUpstreamValue := false hasUpstreamValue := false
hasDifference := false hasDifference := false
for _, channel := range successfulChannels { for _, channel := range successfulChannels {
var upstreamValue interface{} = nil var upstreamValue interface{} = nil
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
if val, exists := upstreamRatio[modelName]; exists { if val, exists := upstreamRatio[modelName]; exists {
upstreamValue = val upstreamValue = val
hasUpstreamValue = true hasUpstreamValue = true
if localValue != nil && localValue != val { if localValue != nil && localValue != val {
hasDifference = true hasDifference = true
} else if localValue == val { } else if localValue == val {
upstreamValue = "same" upstreamValue = "same"
} }
} }
} }
if upstreamValue == nil && localValue == nil { if upstreamValue == nil && localValue == nil {
upstreamValue = "same" upstreamValue = "same"
} }
if localValue == nil && upstreamValue != nil && upstreamValue != "same" { if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
hasDifference = true hasDifference = true
} }
upstreamValues[channel.name] = upstreamValue upstreamValues[channel.name] = upstreamValue
confidenceValues[channel.name] = confidenceMap[channel.name][modelName] confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
} }
shouldInclude := false shouldInclude := false
if localValue != nil { if localValue != nil {
if hasDifference { if hasDifference {
shouldInclude = true shouldInclude = true
} }
} else { } else {
if hasUpstreamValue { if hasUpstreamValue {
shouldInclude = true shouldInclude = true
} }
} }
if shouldInclude { if shouldInclude {
if differences[modelName] == nil { if differences[modelName] == nil {
differences[modelName] = make(map[string]dto.DifferenceItem) differences[modelName] = make(map[string]dto.DifferenceItem)
} }
differences[modelName][ratioType] = dto.DifferenceItem{ differences[modelName][ratioType] = dto.DifferenceItem{
Current: localValue, Current: localValue,
Upstreams: upstreamValues, Upstreams: upstreamValues,
Confidence: confidenceValues, Confidence: confidenceValues,
} }
} }
} }
} }
channelHasDiff := make(map[string]bool) channelHasDiff := make(map[string]bool)
for _, ratioMap := range differences { for _, ratioMap := range differences {
for _, item := range ratioMap { for _, item := range ratioMap {
for chName, val := range item.Upstreams { for chName, val := range item.Upstreams {
if val != nil && val != "same" { if val != nil && val != "same" {
channelHasDiff[chName] = true channelHasDiff[chName] = true
} }
} }
} }
} }
for modelName, ratioMap := range differences { for modelName, ratioMap := range differences {
for ratioType, item := range ratioMap { for ratioType, item := range ratioMap {
for chName := range item.Upstreams { for chName := range item.Upstreams {
if !channelHasDiff[chName] { if !channelHasDiff[chName] {
delete(item.Upstreams, chName) delete(item.Upstreams, chName)
delete(item.Confidence, chName) delete(item.Confidence, chName)
} }
} }
allSame := true allSame := true
for _, v := range item.Upstreams { for _, v := range item.Upstreams {
if v != "same" { if v != "same" {
allSame = false allSame = false
break break
} }
} }
if len(item.Upstreams) == 0 || allSame { if len(item.Upstreams) == 0 || allSame {
delete(ratioMap, ratioType) delete(ratioMap, ratioType)
} else { } else {
differences[modelName][ratioType] = item differences[modelName][ratioType] = item
} }
} }
if len(ratioMap) == 0 { if len(ratioMap) == 0 {
delete(differences, modelName) delete(differences, modelName)
} }
} }
return differences return differences
} }
func GetSyncableChannels(c *gin.Context) { func GetSyncableChannels(c *gin.Context) {
channels, err := model.GetAllChannels(0, 0, true, false) channels, err := model.GetAllChannels(0, 0, true, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
var syncableChannels []dto.SyncableChannel var syncableChannels []dto.SyncableChannel
for _, channel := range channels { for _, channel := range channels {
if channel.GetBaseURL() != "" { if channel.GetBaseURL() != "" {
syncableChannels = append(syncableChannels, dto.SyncableChannel{ syncableChannels = append(syncableChannels, dto.SyncableChannel{
ID: channel.Id, ID: channel.Id,
Name: channel.Name, Name: channel.Name,
BaseURL: channel.GetBaseURL(), BaseURL: channel.GetBaseURL(),
Status: channel.Status, Status: channel.Status,
}) })
} }
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": syncableChannels, "data": syncableChannels,
}) })
} }

View File

@@ -2,21 +2,22 @@ package controller
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
constant2 "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
"one-api/types" "one-api/types"
"strings" "strings"
@@ -24,81 +25,177 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError { func relayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
var err *types.NewAPIError var err *types.NewAPIError
switch relayMode { switch info.RelayMode {
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits: case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
err = relay.ImageHelper(c) err = relay.ImageHelper(c, info)
case relayconstant.RelayModeAudioSpeech: case relayconstant.RelayModeAudioSpeech:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranslation: case relayconstant.RelayModeAudioTranslation:
fallthrough fallthrough
case relayconstant.RelayModeAudioTranscription: case relayconstant.RelayModeAudioTranscription:
err = relay.AudioHelper(c) err = relay.AudioHelper(c, info)
case relayconstant.RelayModeRerank: case relayconstant.RelayModeRerank:
err = relay.RerankHelper(c, relayMode) err = relay.RerankHelper(c, info)
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
err = relay.EmbeddingHelper(c) err = relay.EmbeddingHelper(c, info)
case relayconstant.RelayModeResponses: case relayconstant.RelayModeResponses:
err = relay.ResponsesHelper(c) err = relay.ResponsesHelper(c, info)
case relayconstant.RelayModeGemini:
if strings.Contains(c.Request.URL.Path, "embed") {
err = relay.GeminiEmbeddingHandler(c)
} else {
err = relay.GeminiHelper(c)
}
default: default:
err = relay.TextHelper(c) err = relay.TextHelper(c, info)
} }
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
modelName := c.GetString("original_model")
tokenId := c.GetInt("token_id")
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.GetErrorType()
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
}
return err return err
} }
func Relay(c *gin.Context) { func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewAPIError {
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path) var err *types.NewAPIError
if strings.Contains(c.Request.URL.Path, "embed") {
err = relay.GeminiEmbeddingHandler(c, info)
} else {
err = relay.GeminiHelper(c, info)
}
return err
}
func Relay(c *gin.Context, relayFormat types.RelayFormat) {
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group") group := c.GetString("group")
originalModel := c.GetString("original_model") originalModel := c.GetString("original_model")
var newAPIError *types.NewAPIError
var (
newAPIError *types.NewAPIError
ws *websocket.Conn
)
if relayFormat == types.RelayFormatOpenAIRealtime {
var err error
ws, err = upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
return
}
defer ws.Close()
}
defer func() {
if newAPIError != nil {
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
switch relayFormat {
case types.RelayFormatOpenAIRealtime:
helper.WssError(c, ws, newAPIError.ToOpenAIError())
case types.RelayFormatClaude:
c.JSON(newAPIError.StatusCode, gin.H{
"type": "error",
"error": newAPIError.ToClaudeError(),
})
default:
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
})
}
}
}()
request, err := helper.GetAndValidateRequest(c, relayFormat)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
return
}
relayInfo, err := relaycommon.GenRelayInfo(c, relayFormat, request, ws)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeGenRelayInfoFailed)
return
}
meta := request.GetTokenCountMeta()
if setting.ShouldCheckPromptSensitive() {
contains, words := service.CheckSensitiveText(meta.CombineText)
if contains {
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
newAPIError = types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return
}
}
tokens, err := service.CountRequestToken(c, meta, relayInfo)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeCountTokenFailed)
return
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, tokens, meta)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeModelPriceError)
return
}
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return
}
defer func() {
// Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil && preConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
}
}()
for i := 0; i <= common.RetryTimes; i++ { for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i) channel, err := getChannel(c, group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c, err.Error()) logger.LogError(c, err.Error())
newAPIError = err newAPIError = err
break break
} }
newAPIError = relayRequest(c, relayMode, channel) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
switch relayFormat {
case types.RelayFormatOpenAIRealtime:
newAPIError = relay.WssHelper(c, relayInfo)
case types.RelayFormatClaude:
newAPIError = relay.ClaudeHelper(c, relayInfo)
case types.RelayFormatGemini:
newAPIError = geminiRelayHandler(c, relayInfo)
default:
newAPIError = relayHandler(c, relayInfo)
}
if newAPIError == nil { if newAPIError == nil {
return // 成功处理请求,直接返回 return
} else {
if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
modelName := c.GetString("original_model")
tokenId := c.GetInt("token_id")
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = newAPIError.GetErrorType()
other["error_code"] = newAPIError.GetErrorCode()
other["status_code"] = newAPIError.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
}
} }
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
@@ -107,21 +204,11 @@ func Relay(c *gin.Context) {
break break
} }
} }
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 { if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr) logger.LogInfo(c, retryLogStr)
}
if newAPIError != nil {
//if newAPIError.StatusCode == http.StatusTooManyRequests {
// common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
//}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{
"error": newAPIError.ToOpenAIError(),
})
} }
} }
@@ -132,122 +219,6 @@ var upgrader = websocket.Upgrader{
}, },
} }
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
defer ws.Close()
if err != nil {
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
return
}
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var newAPIError *types.NewAPIError
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
newAPIError = err
break
}
newAPIError = wssRequest(c, ws, relayMode, channel)
if newAPIError == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if newAPIError != nil {
//if newAPIError.StatusCode == http.StatusTooManyRequests {
// newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
//}
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
helper.WssError(c, ws, newAPIError.ToOpenAIError())
}
}
func RelayClaude(c *gin.Context) {
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var newAPIError *types.NewAPIError
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
newAPIError = err
break
}
newAPIError = claudeRequest(c, channel)
if newAPIError == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if newAPIError != nil {
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
c.JSON(newAPIError.StatusCode, gin.H{
"type": "error",
"error": newAPIError.ToClaudeError(),
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.WssHelper(c, ws)
}
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.ClaudeHelper(c)
}
func addUsedChannel(c *gin.Context, channelId int) { func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
@@ -270,10 +241,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
} }
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount) channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil { if err != nil {
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) return nil, types.NewError(fmt.Errorf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error()), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
} }
if channel == nil { if channel == nil {
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) return nil, types.NewError(fmt.Errorf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
} }
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel) newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil { if newAPIError != nil {
@@ -327,42 +298,52 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
// 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况 // 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
service.DisableChannel(channelError, err.Error()) service.DisableChannel(channelError, err.Error())
} }
} }
func RelayMidjourney(c *gin.Context) { func RelayMidjourney(c *gin.Context) {
relayMode := c.GetInt("relay_mode") relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatMjProxy, nil, nil)
var err *dto.MidjourneyResponse
switch relayMode { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"description": fmt.Sprintf("failed to generate relay info: %s", err.Error()),
"type": "upstream_error",
"code": 4,
})
return
}
var mjErr *dto.MidjourneyResponse
switch relayInfo.RelayMode {
case relayconstant.RelayModeMidjourneyNotify: case relayconstant.RelayModeMidjourneyNotify:
err = relay.RelayMidjourneyNotify(c) mjErr = relay.RelayMidjourneyNotify(c)
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition: case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition:
err = relay.RelayMidjourneyTask(c, relayMode) mjErr = relay.RelayMidjourneyTask(c, relayInfo.RelayMode)
case relayconstant.RelayModeMidjourneyTaskImageSeed: case relayconstant.RelayModeMidjourneyTaskImageSeed:
err = relay.RelayMidjourneyTaskImageSeed(c) mjErr = relay.RelayMidjourneyTaskImageSeed(c)
case relayconstant.RelayModeSwapFace: case relayconstant.RelayModeSwapFace:
err = relay.RelaySwapFace(c) mjErr = relay.RelaySwapFace(c, relayInfo)
default: default:
err = relay.RelayMidjourneySubmit(c, relayMode) mjErr = relay.RelayMidjourneySubmit(c, relayInfo)
} }
//err = relayMidjourneySubmit(c, relayMode) //err = relayMidjourneySubmit(c, relayMode)
log.Println(err) log.Println(mjErr)
if err != nil { if mjErr != nil {
statusCode := http.StatusBadRequest statusCode := http.StatusBadRequest
if err.Code == 30 { if mjErr.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。" mjErr.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
statusCode = http.StatusTooManyRequests statusCode = http.StatusTooManyRequests
} }
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"description": fmt.Sprintf("%s %s", err.Description, err.Result), "description": fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result),
"type": "upstream_error", "type": "upstream_error",
"code": err.Code, "code": mjErr.Code,
}) })
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", err.Description, err.Result))) logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code %d): %s", channelId, statusCode, fmt.Sprintf("%s %s", mjErr.Description, mjErr.Result)))
} }
} }
@@ -404,7 +385,7 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, newAPIError := getChannel(c, group, originalModel, i) channel, newAPIError := getChannel(c, group, originalModel, i)
if newAPIError != nil { if newAPIError != nil {
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error())) logger.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", newAPIError.Error()))
taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError) taskErr = service.TaskErrorWrapperLocal(newAPIError.Err, "get_channel_failed", http.StatusInternalServerError)
break break
} }
@@ -412,7 +393,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
//middleware.SetupContextForSelectedChannel(c, channel, originalModel) //middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
@@ -422,7 +403,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 { if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr) logger.LogInfo(c, retryLogStr)
} }
if taskErr != nil { if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests { if taskErr.StatusCode == http.StatusTooManyRequests {

View File

@@ -10,6 +10,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
"sort" "sort"
@@ -54,9 +55,9 @@ func UpdateTaskBulk() {
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
} else { } else {
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
} }
} }
if len(taskChannelM) == 0 { if len(taskChannelM) == 0 {
@@ -86,14 +87,14 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
for channelId, taskIds := range taskChannelM { for channelId, taskIds := range taskChannelM {
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
} }
} }
return nil return nil
} }
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 { if len(taskIds) == 0 {
return nil return nil
} }
@@ -106,7 +107,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"progress": "100%", "progress": "100%",
}) })
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) common.SysLog(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
} }
return err return err
} }
@@ -118,23 +119,23 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
"ids": taskIds, "ids": taskIds,
}) })
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err))
return err return err
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
} }
defer resp.Body.Close() defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) common.SysLog(fmt.Sprintf("Get Task parse body error: %v", err))
return err return err
} }
var responseItems dto.TaskResponse[[]dto.SunoDataResponse] var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
err = json.Unmarshal(responseBody, &responseItems) err = json.Unmarshal(responseBody, &responseItems)
if err != nil { if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody))) logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
return err return err
} }
if !responseItems.IsSuccess() { if !responseItems.IsSuccess() {
@@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%" task.Progress = "100%"
//err = model.CacheUpdateUserQuota(task.UserId) ? //err = model.CacheUpdateUserQuota(task.UserId) ?
if err != nil { if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error()) logger.LogError(ctx, "error update user quota cache: "+err.Error())
} else { } else {
quota := task.Quota quota := task.Quota
if quota != 0 { if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota, false) err = model.IncreaseUserQuota(task.UserId, quota, false)
if err != nil { if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error()) logger.LogError(ctx, "fail to increase user quota: "+err.Error())
} }
logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s", task.TaskID, common.LogQuota(quota)) logContent := fmt.Sprintf("异步任务执行失败 %s补偿 %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent) model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
} }
} }
@@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
err = task.Update() err = task.Update()
if err != nil { if err != nil {
common.SysError("UpdateMidjourneyTask task error: " + err.Error()) common.SysLog("UpdateMidjourneyTask task error: " + err.Error())
} }
} }
return nil return nil

View File

@@ -8,6 +8,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/relay" "one-api/relay"
"one-api/relay/channel" "one-api/relay/channel"
@@ -18,14 +19,14 @@ import (
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
for channelId, taskIds := range taskChannelM { for channelId, taskIds := range taskChannelM {
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil { if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
} }
} }
return nil return nil
} }
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
if len(taskIds) == 0 { if len(taskIds) == 0 {
return nil return nil
} }
@@ -37,7 +38,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
"progress": "100%", "progress": "100%",
}) })
if errUpdate != nil { if errUpdate != nil {
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
} }
return fmt.Errorf("CacheGetChannel failed: %w", err) return fmt.Errorf("CacheGetChannel failed: %w", err)
} }
@@ -47,7 +48,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
} }
for _, taskId := range taskIds { for _, taskId := range taskIds {
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
} }
} }
return nil return nil
@@ -61,7 +62,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task := taskM[taskId] task := taskM[taskId]
if task == nil { if task == nil {
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId) return fmt.Errorf("task %s not found", taskId)
} }
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{ resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
@@ -112,7 +113,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.StartTime = now task.StartTime = now
} }
case model.TaskStatusSuccess: case model.TaskStatusSuccess:
task.Progress = "100%" task.Progress = "100%"
if task.FinishTime == 0 { if task.FinishTime == 0 {
task.FinishTime = now task.FinishTime = now
} }
@@ -124,13 +125,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.FinishTime = now task.FinishTime = now
} }
task.FailReason = taskResult.Reason task.FailReason = taskResult.Reason
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
quota := task.Quota quota := task.Quota
if quota != 0 { if quota != 0 {
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil { if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
common.LogError(ctx, "Failed to increase user quota: "+err.Error()) logger.LogError(ctx, "Failed to increase user quota: "+err.Error())
} }
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota)) logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent) model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
} }
default: default:
@@ -140,7 +141,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
task.Progress = taskResult.Progress task.Progress = taskResult.Progress
} }
if err := task.Update(); err != nil { if err := task.Update(); err != nil {
common.SysError("UpdateVideoTask task error: " + err.Error()) common.SysLog("UpdateVideoTask task error: " + err.Error())
} }
return nil return nil

View File

@@ -102,7 +102,7 @@ func AddToken(c *gin.Context) {
"success": false, "success": false,
"message": "生成令牌失败", "message": "生成令牌失败",
}) })
common.SysError("failed to generate token key: " + err.Error()) common.SysLog("failed to generate token key: " + err.Error())
return return
} }
cleanToken := model.Token{ cleanToken := model.Token{

View File

@@ -5,6 +5,7 @@ import (
"log" "log"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting" "one-api/setting"
@@ -231,7 +232,7 @@ func EpayNotify(c *gin.Context) {
return return
} }
log.Printf("易支付回调更新用户成功 %v", topUp) log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", common.LogQuota(quotaToAdd), topUp.Money)) model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%f", logger.LogQuota(quotaToAdd), topUp.Money))
} }
} else { } else {
log.Printf("易支付异常回调: %v", verifyInfo) log.Printf("易支付异常回调: %v", verifyInfo)

View File

@@ -70,7 +70,7 @@ func Setup2FA(c *gin.Context) {
"success": false, "success": false,
"message": "生成2FA密钥失败", "message": "生成2FA密钥失败",
}) })
common.SysError("生成TOTP密钥失败: " + err.Error()) common.SysLog("生成TOTP密钥失败: " + err.Error())
return return
} }
@@ -81,7 +81,7 @@ func Setup2FA(c *gin.Context) {
"success": false, "success": false,
"message": "生成备用码失败", "message": "生成备用码失败",
}) })
common.SysError("生成备用码失败: " + err.Error()) common.SysLog("生成备用码失败: " + err.Error())
return return
} }
@@ -115,7 +115,7 @@ func Setup2FA(c *gin.Context) {
"success": false, "success": false,
"message": "保存备用码失败", "message": "保存备用码失败",
}) })
common.SysError("保存备用码失败: " + err.Error()) common.SysLog("保存备用码失败: " + err.Error())
return return
} }
@@ -294,7 +294,7 @@ func Get2FAStatus(c *gin.Context) {
// 获取剩余备用码数量 // 获取剩余备用码数量
backupCount, err := model.GetUnusedBackupCodeCount(userId) backupCount, err := model.GetUnusedBackupCodeCount(userId)
if err != nil { if err != nil {
common.SysError("获取备用码数量失败: " + err.Error()) common.SysLog("获取备用码数量失败: " + err.Error())
} else { } else {
status["backup_codes_remaining"] = backupCount status["backup_codes_remaining"] = backupCount
} }
@@ -368,7 +368,7 @@ func RegenerateBackupCodes(c *gin.Context) {
"success": false, "success": false,
"message": "生成备用码失败", "message": "生成备用码失败",
}) })
common.SysError("生成备用码失败: " + err.Error()) common.SysLog("生成备用码失败: " + err.Error())
return return
} }
@@ -378,7 +378,7 @@ func RegenerateBackupCodes(c *gin.Context) {
"success": false, "success": false,
"message": "保存备用码失败", "message": "保存备用码失败",
}) })
common.SysError("保存备用码失败: " + err.Error()) common.SysLog("保存备用码失败: " + err.Error())
return return
} }

View File

@@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/model" "one-api/model"
"one-api/setting" "one-api/setting"
"strconv" "strconv"
@@ -192,7 +193,7 @@ func Register(c *gin.Context) {
"success": false, "success": false,
"message": "数据库错误,请稍后重试", "message": "数据库错误,请稍后重试",
}) })
common.SysError(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err)) common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return return
} }
if exist { if exist {
@@ -235,7 +236,7 @@ func Register(c *gin.Context) {
"success": false, "success": false,
"message": "生成默认令牌失败", "message": "生成默认令牌失败",
}) })
common.SysError("failed to generate token key: " + err.Error()) common.SysLog("failed to generate token key: " + err.Error())
return return
} }
// 生成默认令牌 // 生成默认令牌
@@ -342,7 +343,7 @@ func GenerateAccessToken(c *gin.Context) {
"success": false, "success": false,
"message": "生成失败", "message": "生成失败",
}) })
common.SysError("failed to generate key: " + err.Error()) common.SysLog("failed to generate key: " + err.Error())
return return
} }
user.SetAccessToken(key) user.SetAccessToken(key)
@@ -517,7 +518,7 @@ func UpdateUser(c *gin.Context) {
return return
} }
if originUser.Quota != updatedUser.Quota { if originUser.Quota != updatedUser.Quota {
model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", logger.LogQuota(originUser.Quota), logger.LogQuota(updatedUser.Quota)))
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,

View File

@@ -1,5 +1,11 @@
package dto package dto
import (
"one-api/types"
"github.com/gin-gonic/gin"
)
type AudioRequest struct { type AudioRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input string `json:"input"` Input string `json:"input"`
@@ -8,6 +14,18 @@ type AudioRequest struct {
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
} }
func (r *AudioRequest) GetTokenCountMeta() *types.TokenCountMeta {
meta := &types.TokenCountMeta{
CombineText: r.Input,
TokenType: types.TokenTypeTextNumber,
}
return meta
}
func (r *AudioRequest) IsStream(c *gin.Context) bool {
return false
}
type AudioResponse struct { type AudioResponse struct {
Text string `json:"text"` Text string `json:"text"`
} }

View File

@@ -5,6 +5,9 @@ import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/types" "one-api/types"
"strings"
"github.com/gin-gonic/gin"
) )
type ClaudeMetadata struct { type ClaudeMetadata struct {
@@ -81,7 +84,7 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
} }
func (c *ClaudeMediaMessage) GetJsonRowString() string { func (c *ClaudeMediaMessage) GetJsonRowString() string {
jsonContent, _ := json.Marshal(c) jsonContent, _ := common.Marshal(c)
return string(jsonContent) return string(jsonContent)
} }
@@ -199,6 +202,129 @@ type ClaudeRequest struct {
Thinking *Thinking `json:"thinking,omitempty"` Thinking *Thinking `json:"thinking,omitempty"`
} }
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta = types.TokenCountMeta{
TokenType: types.TokenTypeTokenizer,
MaxTokens: int(c.MaxTokens),
}
var texts = make([]string, 0)
var fileMeta = make([]*types.FileMeta, 0)
// system
if c.System != nil {
if c.IsStringSystem() {
sys := c.GetStringSystem()
if sys != "" {
texts = append(texts, sys)
}
} else {
systemMedia := c.ParseSystem()
for _, media := range systemMedia {
switch media.Type {
case "text":
texts = append(texts, media.GetText())
case "image":
if media.Source != nil {
data := media.Source.Url
if data == "" {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
}
}
}
}
}
}
// messages
for _, message := range c.Messages {
tokenCountMeta.MessagesCount++
texts = append(texts, message.Role)
if message.IsStringContent() {
content := message.GetStringContent()
if content != "" {
texts = append(texts, content)
}
continue
}
content, _ := message.ParseContent()
for _, media := range content {
switch media.Type {
case "text":
texts = append(texts, media.GetText())
case "image":
if media.Source != nil {
data := media.Source.Url
if data == "" {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, Data: data})
}
}
case "tool_use":
if media.Name != "" {
texts = append(texts, media.Name)
}
if media.Input != nil {
b, _ := common.Marshal(media.Input)
texts = append(texts, string(b))
}
case "tool_result":
if media.Content != nil {
b, _ := common.Marshal(media.Content)
texts = append(texts, string(b))
}
}
}
}
// tools
if c.Tools != nil {
tools := c.GetTools()
normalTools, webSearchTools := ProcessTools(tools)
if normalTools != nil {
for _, t := range normalTools {
tokenCountMeta.ToolsCount++
if t.Name != "" {
texts = append(texts, t.Name)
}
if t.Description != "" {
texts = append(texts, t.Description)
}
if t.InputSchema != nil {
b, _ := common.Marshal(t.InputSchema)
texts = append(texts, string(b))
}
}
}
if webSearchTools != nil {
for _, t := range webSearchTools {
tokenCountMeta.ToolsCount++
if t.Name != "" {
texts = append(texts, t.Name)
}
if t.UserLocation != nil {
b, _ := common.Marshal(t.UserLocation)
texts = append(texts, string(b))
}
}
}
}
tokenCountMeta.CombineText = strings.Join(texts, "\n")
tokenCountMeta.Files = fileMeta
return &tokenCountMeta
}
func (claudeRequest *ClaudeRequest) IsStream(c *gin.Context) bool {
return claudeRequest.Stream
}
func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string { func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
for _, message := range c.Messages { for _, message := range c.Messages {
content, _ := message.ParseContent() content, _ := message.ParseContent()

View File

@@ -1,32 +0,0 @@
package dto
import "encoding/json"
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style json.RawMessage `json:"style,omitempty"`
User json.RawMessage `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background json.RawMessage `json:"background,omitempty"`
Moderation json.RawMessage `json:"moderation,omitempty"`
OutputFormat json.RawMessage `json:"output_format,omitempty"`
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}

View File

@@ -1,5 +1,12 @@
package dto package dto
import (
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type EmbeddingOptions struct { type EmbeddingOptions struct {
Seed int `json:"seed,omitempty"` Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
@@ -24,9 +31,26 @@ type EmbeddingRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
} }
func (r EmbeddingRequest) ParseInput() []string { func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
var texts = make([]string, 0)
inputs := r.ParseInput()
for _, input := range inputs {
texts = append(texts, input)
}
return &types.TokenCountMeta{
CombineText: strings.Join(texts, "\n"),
}
}
func (r *EmbeddingRequest) IsStream(c *gin.Context) bool {
return false
}
func (r *EmbeddingRequest) ParseInput() []string {
if r.Input == nil { if r.Input == nil {
return nil return make([]string, 0)
} }
var input []string var input []string
switch r.Input.(type) { switch r.Input.(type) {

View File

@@ -2,7 +2,10 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/logger"
"one-api/types"
"strings" "strings"
) )
@@ -14,19 +17,75 @@ type GeminiChatRequest struct {
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"` SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
} }
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
var files []*types.FileMeta = make([]*types.FileMeta, 0)
var maxTokens int
if r.GenerationConfig.MaxOutputTokens > 0 {
maxTokens = int(r.GenerationConfig.MaxOutputTokens)
}
var inputTexts []string
for _, content := range r.Contents {
for _, part := range content.Parts {
if part.Text != "" {
inputTexts = append(inputTexts, part.Text)
}
if part.InlineData != nil && part.InlineData.Data != "" {
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeImage,
Data: part.InlineData.Data,
})
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeAudio,
Data: part.InlineData.Data,
})
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeVideo,
Data: part.InlineData.Data,
})
} else {
files = append(files, &types.FileMeta{
FileType: types.FileTypeFile,
Data: part.InlineData.Data,
})
}
}
}
}
inputText := strings.Join(inputTexts, "\n")
return &types.TokenCountMeta{
CombineText: inputText,
Files: files,
MaxTokens: maxTokens,
}
}
func (r *GeminiChatRequest) IsStream(c *gin.Context) bool {
if c.Query("alt") == "sse" {
return true
}
return false
}
func (r *GeminiChatRequest) GetTools() []GeminiChatTool { func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
var tools []GeminiChatTool var tools []GeminiChatTool
if strings.HasSuffix(string(r.Tools), "[") { if strings.HasSuffix(string(r.Tools), "[") {
// is array // is array
if err := common.Unmarshal(r.Tools, &tools); err != nil { if err := common.Unmarshal(r.Tools, &tools); err != nil {
common.LogError(nil, "error_unmarshalling_tools: "+err.Error()) logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
return nil return nil
} }
} else if strings.HasPrefix(string(r.Tools), "{") { } else if strings.HasPrefix(string(r.Tools), "{") {
// is object // is object
singleTool := GeminiChatTool{} singleTool := GeminiChatTool{}
if err := common.Unmarshal(r.Tools, &singleTool); err != nil { if err := common.Unmarshal(r.Tools, &singleTool); err != nil {
common.LogError(nil, "error_unmarshalling_single_tool: "+err.Error()) logger.LogError(nil, "error_unmarshalling_single_tool: "+err.Error())
return nil return nil
} }
tools = []GeminiChatTool{singleTool} tools = []GeminiChatTool{singleTool}
@@ -43,7 +102,7 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
// Marshal the tools to JSON // Marshal the tools to JSON
data, err := common.Marshal(tools) data, err := common.Marshal(tools)
if err != nil { if err != nil {
common.LogError(nil, "error_marshalling_tools: "+err.Error()) logger.LogError(nil, "error_marshalling_tools: "+err.Error())
return return
} }
r.Tools = data r.Tools = data

74
dto/openai_image.go Normal file
View File

@@ -0,0 +1,74 @@
package dto
import (
"encoding/json"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N uint `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style json.RawMessage `json:"style,omitempty"`
User json.RawMessage `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background json.RawMessage `json:"background,omitempty"`
Moderation json.RawMessage `json:"moderation,omitempty"`
OutputFormat json.RawMessage `json:"output_format,omitempty"`
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
}
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
var sizeRatio = 1.0
var qualityRatio = 1.0
if strings.HasPrefix(i.Model, "dall-e") {
// Size
if i.Size == "256x256" {
sizeRatio = 0.4
} else if i.Size == "512x512" {
sizeRatio = 0.45
} else if i.Size == "1024x1024" {
sizeRatio = 1
} else if i.Size == "1024x1792" || i.Size == "1792x1024" {
sizeRatio = 2
}
if i.Model == "dall-e-3" && i.Quality == "hd" {
qualityRatio = 2.0
if i.Size == "1024x1792" || i.Size == "1792x1024" {
qualityRatio = 1.5
}
}
}
// not support token count for dalle
return &types.TokenCountMeta{
CombineText: i.Prompt,
MaxTokens: 1584,
ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
}
}
func (i *ImageRequest) IsStream(c *gin.Context) bool {
return false
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
}
type ImageData struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
RevisedPrompt string `json:"revised_prompt"`
}

View File

@@ -2,8 +2,12 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"fmt"
"one-api/common" "one-api/common"
"one-api/types"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type ResponseFormat struct { type ResponseFormat struct {
@@ -67,6 +71,116 @@ type GeneralOpenAIRequest struct {
Extra map[string]json.RawMessage `json:"-"` Extra map[string]json.RawMessage `json:"-"`
} }
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta types.TokenCountMeta
var texts = make([]string, 0)
var fileMeta = make([]*types.FileMeta, 0)
if r.Prompt != nil {
switch v := r.Prompt.(type) {
case string:
texts = append(texts, v)
case []any:
for _, item := range v {
if str, ok := item.(string); ok {
texts = append(texts, str)
}
}
default:
texts = append(texts, fmt.Sprintf("%v", r.Prompt))
}
}
if r.Input != nil {
inputs := r.ParseInput()
texts = append(texts, inputs...)
}
if r.MaxCompletionTokens > r.MaxTokens {
tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
} else {
tokenCountMeta.MaxTokens = int(r.MaxTokens)
}
for _, message := range r.Messages {
tokenCountMeta.MessagesCount++
texts = append(texts, message.Role)
if message.Content != nil {
if message.Name != nil {
tokenCountMeta.NameCount++
texts = append(texts, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == ContentTypeImageURL {
imageUrl := m.GetImageMedia()
if imageUrl != nil {
meta := &types.FileMeta{
FileType: types.FileTypeImage,
}
meta.Data = imageUrl.Url
meta.Detail = imageUrl.Detail
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeInputAudio {
inputAudio := m.GetInputAudio()
if inputAudio != nil {
meta := &types.FileMeta{
FileType: types.FileTypeAudio,
}
meta.Data = inputAudio.Data
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeFile {
file := m.GetFile()
if file != nil {
meta := &types.FileMeta{
FileType: types.FileTypeFile,
}
meta.Data = file.FileData
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeVideoUrl {
videoUrl := m.GetVideoUrl()
if videoUrl != nil {
meta := &types.FileMeta{
FileType: types.FileTypeVideo,
}
meta.Data = videoUrl.Url
fileMeta = append(fileMeta, meta)
}
} else {
texts = append(texts, m.Text)
}
}
}
}
if r.Tools != nil {
openaiTools := r.Tools
for _, tool := range openaiTools {
tokenCountMeta.ToolsCount++
texts = append(texts, tool.Function.Name)
if tool.Function.Description != "" {
texts = append(texts, tool.Function.Description)
}
if tool.Function.Parameters != nil {
texts = append(texts, fmt.Sprintf("%v", tool.Function.Parameters))
}
}
//toolTokens := CountTokenInput(countStr, request.Model)
//tkm += 8
//tkm += toolTokens
}
tokenCountMeta.CombineText = strings.Join(texts, "\n")
tokenCountMeta.Files = fileMeta
return &tokenCountMeta
}
func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
return r.Stream
}
func (r *GeneralOpenAIRequest) ToMap() map[string]any { func (r *GeneralOpenAIRequest) ToMap() map[string]any {
result := make(map[string]any) result := make(map[string]any)
data, _ := common.Marshal(r) data, _ := common.Marshal(r)
@@ -202,6 +316,21 @@ func (m *MediaContent) GetFile() *MessageFile {
return nil return nil
} }
func (m *MediaContent) GetVideoUrl() *MessageVideoUrl {
if m.VideoUrl != nil {
if _, ok := m.VideoUrl.(*MessageVideoUrl); ok {
return m.VideoUrl.(*MessageVideoUrl)
}
if itemMap, ok := m.VideoUrl.(map[string]any); ok {
out := &MessageVideoUrl{
Url: common.Interface2String(itemMap["url"]),
}
return out
}
}
return nil
}
type MessageImageUrl struct { type MessageImageUrl struct {
Url string `json:"url"` Url string `json:"url"`
Detail string `json:"detail"` Detail string `json:"detail"`
@@ -233,6 +362,7 @@ const (
ContentTypeInputAudio = "input_audio" ContentTypeInputAudio = "input_audio"
ContentTypeFile = "file" ContentTypeFile = "file"
ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别 ContentTypeVideoUrl = "video_url" // 阿里百炼视频识别
//ContentTypeAudioUrl = "audio_url"
) )
func (m *Message) GetPrefix() bool { func (m *Message) GetPrefix() bool {
@@ -623,7 +753,7 @@ type WebSearchOptions struct {
// https://platform.openai.com/docs/api-reference/responses/create // https://platform.openai.com/docs/api-reference/responses/create
type OpenAIResponsesRequest struct { type OpenAIResponsesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input json.RawMessage `json:"input,omitempty"` Input any `json:"input,omitempty"`
Include json.RawMessage `json:"include,omitempty"` Include json.RawMessage `json:"include,omitempty"`
Instructions json.RawMessage `json:"instructions,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
@@ -645,28 +775,145 @@ type OpenAIResponsesRequest struct {
Prompt json.RawMessage `json:"prompt,omitempty"` Prompt json.RawMessage `json:"prompt,omitempty"`
} }
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
var fileMeta = make([]*types.FileMeta, 0)
var texts = make([]string, 0)
if r.Input != nil {
inputs := r.ParseInput()
for _, input := range inputs {
if input.Type == "input_image" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage,
Data: input.ImageUrl,
Detail: input.Detail,
})
} else if input.Type == "input_file" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeFile,
Data: input.FileUrl,
})
} else {
texts = append(texts, input.Text)
}
}
}
if len(r.Instructions) > 0 {
texts = append(texts, string(r.Instructions))
}
if len(r.Metadata) > 0 {
texts = append(texts, string(r.Metadata))
}
if len(r.Text) > 0 {
texts = append(texts, string(r.Text))
}
if len(r.ToolChoice) > 0 {
texts = append(texts, string(r.ToolChoice))
}
if len(r.Prompt) > 0 {
texts = append(texts, string(r.Prompt))
}
if len(r.Tools) > 0 {
toolStr, _ := common.Marshal(r.Tools)
texts = append(texts, string(toolStr))
}
return &types.TokenCountMeta{
CombineText: strings.Join(texts, "\n"),
Files: fileMeta,
MaxTokens: int(r.MaxOutputTokens),
}
}
func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
return r.Stream
}
type Reasoning struct { type Reasoning struct {
Effort string `json:"effort,omitempty"` Effort string `json:"effort,omitempty"`
Summary string `json:"summary,omitempty"` Summary string `json:"summary,omitempty"`
} }
//type ResponsesToolsCall struct { type MediaInput struct {
// Type string `json:"type"` Type string `json:"type"`
// // Web Search Text string `json:"text,omitempty"`
// UserLocation json.RawMessage `json:"user_location,omitempty"` FileUrl string `json:"file_url,omitempty"`
// SearchContextSize string `json:"search_context_size,omitempty"` ImageUrl string `json:"image_url,omitempty"`
// // File Search Detail string `json:"detail,omitempty"` // 仅 input_image 有效
// VectorStoreIds []string `json:"vector_store_ids,omitempty"` }
// MaxNumResults uint `json:"max_num_results,omitempty"`
// Filters json.RawMessage `json:"filters,omitempty"` // ParseInput parses the Responses API `input` field into a normalized slice of MediaInput.
// // Computer Use // Reference implementation mirrors Message.ParseContent:
// DisplayWidth uint `json:"display_width,omitempty"` // - input can be a string, treated as an input_text item
// DisplayHeight uint `json:"display_height,omitempty"` // - input can be an array of objects with a `type` field
// Environment string `json:"environment,omitempty"` // supported types: input_text, input_image, input_file
// // Function func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
// Name string `json:"name,omitempty"` if r.Input == nil {
// Description string `json:"description,omitempty"` return nil
// Parameters json.RawMessage `json:"parameters,omitempty"` }
// Function json.RawMessage `json:"function,omitempty"`
// Container json.RawMessage `json:"container,omitempty"` var inputs []MediaInput
//}
// Try string first
if str, ok := r.Input.(string); ok {
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
return inputs
}
// Try array of parts
if array, ok := r.Input.([]any); ok {
for _, itemAny := range array {
// Already parsed MediaInput
if media, ok := itemAny.(MediaInput); ok {
inputs = append(inputs, media)
continue
}
// Generic map
item, ok := itemAny.(map[string]any)
if !ok {
continue
}
typeVal, ok := item["type"].(string)
if !ok {
continue
}
switch typeVal {
case "input_text":
text, _ := item["text"].(string)
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
case "input_image":
// image_url may be string or object with url field
var imageUrl string
switch v := item["image_url"].(type) {
case string:
imageUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
imageUrl = url
}
}
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
case "input_file":
// file_url may be string or object with url field
var fileUrl string
switch v := item["file_url"].(type) {
case string:
fileUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
fileUrl = url
}
}
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
}
}
}
return inputs
}

24
dto/request_common.go Normal file
View File

@@ -0,0 +1,24 @@
package dto
import (
"github.com/gin-gonic/gin"
"one-api/types"
)
type Request interface {
GetTokenCountMeta() *types.TokenCountMeta
IsStream(c *gin.Context) bool
}
type BaseRequest struct {
}
func (b *BaseRequest) GetTokenCountMeta() *types.TokenCountMeta {
return &types.TokenCountMeta{
TokenType: types.TokenTypeTokenizer,
}
}
func (b *BaseRequest) IsStream(c *gin.Context) bool {
return false
}

View File

@@ -1,5 +1,12 @@
package dto package dto
import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/types"
"strings"
)
type RerankRequest struct { type RerankRequest struct {
Documents []any `json:"documents"` Documents []any `json:"documents"`
Query string `json:"query"` Query string `json:"query"`
@@ -10,6 +17,26 @@ type RerankRequest struct {
OverLapTokens int `json:"overlap_tokens,omitempty"` OverLapTokens int `json:"overlap_tokens,omitempty"`
} }
func (r *RerankRequest) IsStream(c *gin.Context) bool {
return false
}
func (r *RerankRequest) GetTokenCountMeta() *types.TokenCountMeta {
var texts = make([]string, 0)
for _, document := range r.Documents {
texts = append(texts, fmt.Sprintf("%v", document))
}
if r.Query != "" {
texts = append(texts, r.Query)
}
return &types.TokenCountMeta{
CombineText: strings.Join(texts, "\n"),
}
}
func (r *RerankRequest) GetReturnDocuments() bool { func (r *RerankRequest) GetReturnDocuments() bool {
if r.ReturnDocuments == nil { if r.ReturnDocuments == nil {
return false return false

View File

@@ -1,23 +1,26 @@
package common package logger
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"io" "io"
"log" "log"
"one-api/common"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
"time" "time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
) )
const ( const (
loggerINFO = "INFO" loggerINFO = "INFO"
loggerWarn = "WARN" loggerWarn = "WARN"
loggerError = "ERR" loggerError = "ERR"
loggerDebug = "DEBUG"
) )
const maxLogCount = 1000000 const maxLogCount = 1000000
@@ -27,7 +30,10 @@ var setupLogLock sync.Mutex
var setupLogWorking bool var setupLogWorking bool
func SetupLogger() { func SetupLogger() {
if *LogDir != "" { defer func() {
setupLogWorking = false
}()
if *common.LogDir != "" {
ok := setupLogLock.TryLock() ok := setupLogLock.TryLock()
if !ok { if !ok {
log.Println("setup log is already working") log.Println("setup log is already working")
@@ -35,9 +41,8 @@ func SetupLogger() {
} }
defer func() { defer func() {
setupLogLock.Unlock() setupLogLock.Unlock()
setupLogWorking = false
}() }()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405"))) logPath := filepath.Join(*common.LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102150405")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
log.Fatal("failed to open log file") log.Fatal("failed to open log file")
@@ -47,16 +52,6 @@ func SetupLogger() {
} }
} }
func SysLog(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysError(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func LogInfo(ctx context.Context, msg string) { func LogInfo(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg) logHelper(ctx, loggerINFO, msg)
} }
@@ -69,12 +64,18 @@ func LogError(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg) logHelper(ctx, loggerError, msg)
} }
func LogDebug(ctx context.Context, msg string) {
if common.DebugEnabled {
logHelper(ctx, loggerDebug, msg)
}
}
func logHelper(ctx context.Context, level string, msg string) { func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter writer := gin.DefaultErrorWriter
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
id := ctx.Value(RequestIdKey) id := ctx.Value(common.RequestIdKey)
if id == nil { if id == nil {
id = "SYSTEM" id = "SYSTEM"
} }
@@ -90,23 +91,17 @@ func logHelper(ctx context.Context, level string, msg string) {
} }
} }
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
func LogQuota(quota int) string { func LogQuota(quota int) string {
if DisplayInCurrencyEnabled { if common.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit) return fmt.Sprintf("%.6f 额度", float64(quota)/common.QuotaPerUnit)
} else { } else {
return fmt.Sprintf("%d 点额度", quota) return fmt.Sprintf("%d 点额度", quota)
} }
} }
func FormatQuota(quota int) string { func FormatQuota(quota int) string {
if DisplayInCurrencyEnabled { if common.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f", float64(quota)/QuotaPerUnit) return fmt.Sprintf("%.6f", float64(quota)/common.QuotaPerUnit)
} else { } else {
return fmt.Sprintf("%d", quota) return fmt.Sprintf("%d", quota)
} }

View File

@@ -8,6 +8,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/controller" "one-api/controller"
"one-api/logger"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/router" "one-api/router"
@@ -60,13 +61,13 @@ func main() {
} }
if common.MemoryCacheEnabled { if common.MemoryCacheEnabled {
common.SysLog("memory cache enabled") common.SysLog("memory cache enabled")
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) common.SysLog(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
// Add panic recovery and retry for InitChannelCache // Add panic recovery and retry for InitChannelCache
func() { func() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r)) common.SysLog(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
// Retry once // Retry once
_, _, fixErr := model.FixAbility() _, _, fixErr := model.FixAbility()
if fixErr != nil { if fixErr != nil {
@@ -125,7 +126,7 @@ func main() {
// Initialize HTTP server // Initialize HTTP server
server := gin.New() server := gin.New()
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
common.SysError(fmt.Sprintf("panic detected: %v", err)) common.SysLog(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{ "error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
@@ -171,7 +172,7 @@ func InitResources() error {
// 加载环境变量 // 加载环境变量
common.InitEnv() common.InitEnv()
common.SetupLogger() logger.SetupLogger()
// Initialize model settings // Initialize model settings
ratio_setting.InitRatioSettings() ratio_setting.InitRatioSettings()

View File

@@ -107,11 +107,11 @@ func Distribute() func(c *gin.Context) {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) // common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
// message = "数据库一致性已被破坏,请联系管理员" // message = "数据库一致性已被破坏,请联系管理员"
//} //}
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message) abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, string(types.ErrorCodeModelNotFound))
return return
} }
if channel == nil { if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model)) abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 无可用渠道distributor", userGroup, modelRequest.Model), string(types.ErrorCodeModelNotFound))
return return
} }
} }

View File

@@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err)) common.SysLog(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) common.SysLog(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{ "error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),

View File

@@ -37,7 +37,7 @@ func TurnstileCheck() gin.HandlerFunc {
"remoteip": {c.ClientIP()}, "remoteip": {c.ClientIP()},
}) })
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysLog(err.Error())
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -49,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res) err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysLog(err.Error())
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -4,18 +4,24 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/logger"
) )
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) { func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string, code ...string) {
codeStr := ""
if len(code) > 0 {
codeStr = code[0]
}
userId := c.GetInt("id") userId := c.GetInt("id")
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{
"error": gin.H{ "error": gin.H{
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
"type": "new_api_error", "type": "new_api_error",
"code": codeStr,
}, },
}) })
c.Abort() c.Abort()
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message)) logger.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
} }
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) { func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
@@ -25,5 +31,5 @@ func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, descri
"code": code, "code": code,
}) })
c.Abort() c.Abort()
common.LogError(c.Request.Context(), description) logger.LogError(c.Request.Context(), description)
} }

View File

@@ -294,13 +294,13 @@ func FixAbility() (int, int, error) {
if common.UsingSQLite { if common.UsingSQLite {
err := DB.Exec("DELETE FROM abilities").Error err := DB.Exec("DELETE FROM abilities").Error
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
return 0, 0, err return 0, 0, err
} }
} else { } else {
err := DB.Exec("TRUNCATE TABLE abilities").Error err := DB.Exec("TRUNCATE TABLE abilities").Error
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error())) common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
return 0, 0, err return 0, 0, err
} }
} }
@@ -320,7 +320,7 @@ func FixAbility() (int, int, error) {
// Delete all abilities of this channel // Delete all abilities of this channel
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error())) common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
failCount += len(chunk) failCount += len(chunk)
continue continue
} }
@@ -328,7 +328,7 @@ func FixAbility() (int, int, error) {
for _, channel := range chunk { for _, channel := range chunk {
err = channel.AddAbilities(nil) err = channel.AddAbilities(nil)
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error())) common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
failCount++ failCount++
} else { } else {
successCount++ successCount++

View File

@@ -209,7 +209,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
if channel.OtherInfo != "" { if channel.OtherInfo != "" {
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil { if err != nil {
common.SysError("failed to unmarshal other info: " + err.Error()) common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
} }
} }
return otherInfo return otherInfo
@@ -218,7 +218,7 @@ func (channel *Channel) GetOtherInfo() map[string]interface{} {
func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
otherInfoBytes, err := json.Marshal(otherInfo) otherInfoBytes, err := json.Marshal(otherInfo)
if err != nil { if err != nil {
common.SysError("failed to marshal other info: " + err.Error()) common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err))
return return
} }
channel.OtherInfo = string(otherInfoBytes) channel.OtherInfo = string(otherInfoBytes)
@@ -406,7 +406,11 @@ func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil { if channel.BaseURL == nil {
return "" return ""
} }
return *channel.BaseURL url := *channel.BaseURL
if url == "" {
url = constant.ChannelBaseURLs[channel.Type]
}
return url
} }
func (channel *Channel) GetModelMapping() string { func (channel *Channel) GetModelMapping() string {
@@ -488,7 +492,7 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
ResponseTime: int(responseTime), ResponseTime: int(responseTime),
}).Error }).Error
if err != nil { if err != nil {
common.SysError("failed to update response time: " + err.Error()) common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err))
} }
} }
@@ -498,7 +502,7 @@ func (channel *Channel) UpdateBalance(balance float64) {
Balance: balance, Balance: balance,
}).Error }).Error
if err != nil { if err != nil {
common.SysError("failed to update balance: " + err.Error()) common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err))
} }
} }
@@ -614,7 +618,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if shouldUpdateAbilities { if shouldUpdateAbilities {
err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled)
if err != nil { if err != nil {
common.SysError("failed to update ability status: " + err.Error()) common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err))
} }
} }
}() }()
@@ -642,7 +646,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
} }
err = channel.Save() err = channel.Save()
if err != nil { if err != nil {
common.SysError("failed to update channel status: " + err.Error()) common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err))
return false return false
} }
} }
@@ -704,7 +708,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
for _, channel := range channels { for _, channel := range channels {
err = channel.UpdateAbilities(nil) err = channel.UpdateAbilities(nil)
if err != nil { if err != nil {
common.SysError("failed to update abilities: " + err.Error()) common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err))
} }
} }
} }
@@ -728,7 +732,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil { if err != nil {
common.SysError("failed to update channel used quota: " + err.Error()) common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err))
} }
} }
@@ -821,7 +825,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
if channel.Setting != nil && *channel.Setting != "" { if channel.Setting != nil && *channel.Setting != "" {
err := common.Unmarshal([]byte(*channel.Setting), &setting) err := common.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil { if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error()) common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
channel.Setting = nil // 清空设置以避免后续错误 channel.Setting = nil // 清空设置以避免后续错误
_ = channel.Save() // 保存修改 _ = channel.Save() // 保存修改
} }
@@ -832,7 +836,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
func (channel *Channel) SetSetting(setting dto.ChannelSettings) { func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := common.Marshal(setting) settingBytes, err := common.Marshal(setting)
if err != nil { if err != nil {
common.SysError("failed to marshal setting: " + err.Error()) common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
return return
} }
channel.Setting = common.GetPointer[string](string(settingBytes)) channel.Setting = common.GetPointer[string](string(settingBytes))
@@ -843,7 +847,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
if channel.OtherSettings != "" { if channel.OtherSettings != "" {
err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) err := common.UnmarshalJsonStr(channel.OtherSettings, &setting)
if err != nil { if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error()) common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err))
channel.OtherSettings = "{}" // 清空设置以避免后续错误 channel.OtherSettings = "{}" // 清空设置以避免后续错误
_ = channel.Save() // 保存修改 _ = channel.Save() // 保存修改
} }
@@ -854,7 +858,7 @@ func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings {
func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) {
settingBytes, err := common.Marshal(setting) settingBytes, err := common.Marshal(setting)
if err != nil { if err != nil {
common.SysError("failed to marshal setting: " + err.Error()) common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err))
return return
} }
channel.OtherSettings = string(settingBytes) channel.OtherSettings = string(settingBytes)
@@ -865,7 +869,7 @@ func (channel *Channel) GetParamOverride() map[string]interface{} {
if channel.ParamOverride != nil && *channel.ParamOverride != "" { if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride) err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
if err != nil { if err != nil {
common.SysError("failed to unmarshal param override: " + err.Error()) common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err))
} }
} }
return paramOverride return paramOverride

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/logger"
"os" "os"
"strings" "strings"
"time" "time"
@@ -87,13 +88,13 @@ func RecordLog(userId int, logType int, content string) {
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.SysError("failed to record log: " + err.Error()) common.SysLog("failed to record log: " + err.Error())
} }
} }
func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int, func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string, tokenName string, content string, tokenId int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) { isStream bool, group string, other map[string]interface{}) {
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content)) logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username") username := c.GetString("username")
otherStr := common.MapToJsonStr(other) otherStr := common.MapToJsonStr(other)
// 判断是否需要记录 IP // 判断是否需要记录 IP
@@ -129,7 +130,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(c, "failed to record log: "+err.Error()) logger.LogError(c, "failed to record log: "+err.Error())
} }
} }
@@ -142,7 +143,6 @@ type RecordConsumeLogParams struct {
Quota int `json:"quota"` Quota int `json:"quota"`
Content string `json:"content"` Content string `json:"content"`
TokenId int `json:"token_id"` TokenId int `json:"token_id"`
UserQuota int `json:"user_quota"`
UseTimeSeconds int `json:"use_time_seconds"` UseTimeSeconds int `json:"use_time_seconds"`
IsStream bool `json:"is_stream"` IsStream bool `json:"is_stream"`
Group string `json:"group"` Group string `json:"group"`
@@ -150,7 +150,7 @@ type RecordConsumeLogParams struct {
} }
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
if !common.LogConsumeEnabled { if !common.LogConsumeEnabled {
return return
} }
@@ -189,7 +189,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
} }
err := LOG_DB.Create(log).Error err := LOG_DB.Create(log).Error
if err != nil { if err != nil {
common.LogError(c, "failed to record log: "+err.Error()) logger.LogError(c, "failed to record log: "+err.Error())
} }
if common.DataExportEnabled { if common.DataExportEnabled {
gopool.Go(func() { gopool.Go(func() {

View File

@@ -150,7 +150,7 @@ func loadOptionsFromDatabase() {
for _, option := range options { for _, option := range options {
err := updateOptionMap(option.Key, option.Value) err := updateOptionMap(option.Key, option.Value)
if err != nil { if err != nil {
common.SysError("failed to update option map: " + err.Error()) common.SysLog("failed to update option map: " + err.Error())
} }
} }
} }

View File

@@ -92,7 +92,7 @@ func updatePricing() {
//modelRatios := common.GetModelRatios() //modelRatios := common.GetModelRatios()
enableAbilities, err := GetAllEnableAbilityWithChannels() enableAbilities, err := GetAllEnableAbilityWithChannels()
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
return return
} }
// 预加载模型元数据与供应商一次,避免循环查询 // 预加载模型元数据与供应商一次,避免循环查询

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/logger"
"strconv" "strconv"
"gorm.io/gorm" "gorm.io/gorm"
@@ -148,7 +149,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return 0, errors.New("兑换失败," + err.Error()) return 0, errors.New("兑换失败," + err.Error())
} }
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", common.LogQuota(redemption.Quota), redemption.Id)) RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
return redemption.Quota, nil return redemption.Quota, nil
} }

View File

@@ -91,7 +91,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExpired token.Status = common.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("failed to update token status" + err.Error()) common.SysLog("failed to update token status" + err.Error())
} }
} }
return token, errors.New("该令牌已过期") return token, errors.New("该令牌已过期")
@@ -102,7 +102,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token.Status = common.TokenStatusExhausted token.Status = common.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("failed to update token status" + err.Error()) common.SysLog("failed to update token status" + err.Error())
} }
} }
keyPrefix := key[:3] keyPrefix := key[:3]
@@ -134,7 +134,7 @@ func GetTokenById(id int) (*Token, error) {
if shouldUpdateRedis(true, err) { if shouldUpdateRedis(true, err) {
gopool.Go(func() { gopool.Go(func() {
if err := cacheSetToken(token); err != nil { if err := cacheSetToken(token); err != nil {
common.SysError("failed to update user status cache: " + err.Error()) common.SysLog("failed to update user status cache: " + err.Error())
} }
}) })
} }
@@ -147,7 +147,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
if shouldUpdateRedis(fromDB, err) && token != nil { if shouldUpdateRedis(fromDB, err) && token != nil {
gopool.Go(func() { gopool.Go(func() {
if err := cacheSetToken(*token); err != nil { if err := cacheSetToken(*token); err != nil {
common.SysError("failed to update user status cache: " + err.Error()) common.SysLog("failed to update user status cache: " + err.Error())
} }
}) })
} }
@@ -178,7 +178,7 @@ func (token *Token) Update() (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheSetToken(*token) err := cacheSetToken(*token)
if err != nil { if err != nil {
common.SysError("failed to update token cache: " + err.Error()) common.SysLog("failed to update token cache: " + err.Error())
} }
}) })
} }
@@ -194,7 +194,7 @@ func (token *Token) SelectUpdate() (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheSetToken(*token) err := cacheSetToken(*token)
if err != nil { if err != nil {
common.SysError("failed to update token cache: " + err.Error()) common.SysLog("failed to update token cache: " + err.Error())
} }
}) })
} }
@@ -209,7 +209,7 @@ func (token *Token) Delete() (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheDeleteToken(token.Key) err := cacheDeleteToken(token.Key)
if err != nil { if err != nil {
common.SysError("failed to delete token cache: " + err.Error()) common.SysLog("failed to delete token cache: " + err.Error())
} }
}) })
} }
@@ -269,7 +269,7 @@ func IncreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheIncrTokenQuota(key, int64(quota)) err := cacheIncrTokenQuota(key, int64(quota))
if err != nil { if err != nil {
common.SysError("failed to increase token quota: " + err.Error()) common.SysLog("failed to increase token quota: " + err.Error())
} }
}) })
} }
@@ -299,7 +299,7 @@ func DecreaseTokenQuota(id int, key string, quota int) (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheDecrTokenQuota(key, int64(quota)) err := cacheDecrTokenQuota(key, int64(quota))
if err != nil { if err != nil {
common.SysError("failed to decrease token quota: " + err.Error()) common.SysLog("failed to decrease token quota: " + err.Error())
} }
}) })
} }

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/logger"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -94,7 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
return errors.New("充值失败," + err.Error()) return errors.New("充值失败," + err.Error())
} }
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", common.FormatQuota(int(quota)), topUp.Amount)) RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", logger.FormatQuota(int(quota)), topUp.Amount))
return nil return nil
} }

View File

@@ -243,7 +243,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
if !common.ValidateTOTPCode(t.Secret, code) { if !common.ValidateTOTPCode(t.Secret, code) {
// 增加失败次数 // 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil { if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error()) common.SysLog("更新2FA失败次数失败: " + err.Error())
} }
return false, nil return false, nil
} }
@@ -255,7 +255,7 @@ func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
t.LastUsedAt = &now t.LastUsedAt = &now
if err := t.Update(); err != nil { if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error()) common.SysLog("更新2FA使用记录失败: " + err.Error())
} }
return true, nil return true, nil
@@ -277,7 +277,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
if !valid { if !valid {
// 增加失败次数 // 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil { if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error()) common.SysLog("更新2FA失败次数失败: " + err.Error())
} }
return false, nil return false, nil
} }
@@ -289,7 +289,7 @@ func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
t.LastUsedAt = &now t.LastUsedAt = &now
if err := t.Update(); err != nil { if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error()) common.SysLog("更新2FA使用记录失败: " + err.Error())
} }
return true, nil return true, nil

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
"strconv" "strconv"
"strings" "strings"
@@ -75,7 +76,7 @@ func (user *User) GetSetting() dto.UserSetting {
if user.Setting != "" { if user.Setting != "" {
err := json.Unmarshal([]byte(user.Setting), &setting) err := json.Unmarshal([]byte(user.Setting), &setting)
if err != nil { if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error()) common.SysLog("failed to unmarshal setting: " + err.Error())
} }
} }
return setting return setting
@@ -84,7 +85,7 @@ func (user *User) GetSetting() dto.UserSetting {
func (user *User) SetSetting(setting dto.UserSetting) { func (user *User) SetSetting(setting dto.UserSetting) {
settingBytes, err := json.Marshal(setting) settingBytes, err := json.Marshal(setting)
if err != nil { if err != nil {
common.SysError("failed to marshal setting: " + err.Error()) common.SysLog("failed to marshal setting: " + err.Error())
return return
} }
user.Setting = string(settingBytes) user.Setting = string(settingBytes)
@@ -274,7 +275,7 @@ func inviteUser(inviterId int) (err error) {
func (user *User) TransferAffQuotaToQuota(quota int) error { func (user *User) TransferAffQuotaToQuota(quota int) error {
// 检查quota是否小于最小额度 // 检查quota是否小于最小额度
if float64(quota) < common.QuotaPerUnit { if float64(quota) < common.QuotaPerUnit {
return fmt.Errorf("转移额度最小为%s", common.LogQuota(int(common.QuotaPerUnit))) return fmt.Errorf("转移额度最小为%s", logger.LogQuota(int(common.QuotaPerUnit)))
} }
// 开始数据库事务 // 开始数据库事务
@@ -324,16 +325,16 @@ func (user *User) Insert(inviterId int) error {
return result.Error return result.Error
} }
if common.QuotaForNewUser > 0 { if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if common.QuotaForInvitee > 0 { if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true) _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee, true)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", logger.LogQuota(common.QuotaForInvitee)))
} }
if common.QuotaForInviter > 0 { if common.QuotaForInviter > 0 {
//_ = IncreaseUserQuota(inviterId, common.QuotaForInviter) //_ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", logger.LogQuota(common.QuotaForInviter)))
_ = inviteUser(inviterId) _ = inviteUser(inviterId)
} }
} }
@@ -517,7 +518,7 @@ func IsAdmin(userId int) bool {
var user User var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil { if err != nil {
common.SysError("no such user " + err.Error()) common.SysLog("no such user " + err.Error())
return false return false
} }
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
@@ -572,7 +573,7 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
if shouldUpdateRedis(fromDB, err) { if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() { gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil { if err := updateUserQuotaCache(id, quota); err != nil {
common.SysError("failed to update user quota cache: " + err.Error()) common.SysLog("failed to update user quota cache: " + err.Error())
} }
}) })
} }
@@ -610,7 +611,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
if shouldUpdateRedis(fromDB, err) { if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() { gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil { if err := updateUserGroupCache(id, group); err != nil {
common.SysError("failed to update user group cache: " + err.Error()) common.SysLog("failed to update user group cache: " + err.Error())
} }
}) })
} }
@@ -639,7 +640,7 @@ func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error)
if shouldUpdateRedis(fromDB, err) { if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() { gopool.Go(func() {
if err := updateUserSettingCache(id, setting); err != nil { if err := updateUserSettingCache(id, setting); err != nil {
common.SysError("failed to update user setting cache: " + err.Error()) common.SysLog("failed to update user setting cache: " + err.Error())
} }
}) })
} }
@@ -669,7 +670,7 @@ func IncreaseUserQuota(id int, quota int, db bool) (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheIncrUserQuota(id, int64(quota)) err := cacheIncrUserQuota(id, int64(quota))
if err != nil { if err != nil {
common.SysError("failed to increase user quota: " + err.Error()) common.SysLog("failed to increase user quota: " + err.Error())
} }
}) })
if !db && common.BatchUpdateEnabled { if !db && common.BatchUpdateEnabled {
@@ -694,7 +695,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
gopool.Go(func() { gopool.Go(func() {
err := cacheDecrUserQuota(id, int64(quota)) err := cacheDecrUserQuota(id, int64(quota))
if err != nil { if err != nil {
common.SysError("failed to decrease user quota: " + err.Error()) common.SysLog("failed to decrease user quota: " + err.Error())
} }
}) })
if common.BatchUpdateEnabled { if common.BatchUpdateEnabled {
@@ -750,7 +751,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
}, },
).Error ).Error
if err != nil { if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error()) common.SysLog("failed to update user used quota and request count: " + err.Error())
return return
} }
@@ -767,14 +768,14 @@ func updateUserUsedQuota(id int, quota int) {
}, },
).Error ).Error
if err != nil { if err != nil {
common.SysError("failed to update user used quota: " + err.Error()) common.SysLog("failed to update user used quota: " + err.Error())
} }
} }
func updateUserRequestCount(id int, count int) { func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil { if err != nil {
common.SysError("failed to update user request count: " + err.Error()) common.SysLog("failed to update user request count: " + err.Error())
} }
} }
@@ -785,7 +786,7 @@ func GetUsernameById(id int, fromDB bool) (username string, err error) {
if shouldUpdateRedis(fromDB, err) { if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() { gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil { if err := updateUserNameCache(id, username); err != nil {
common.SysError("failed to update user name cache: " + err.Error()) common.SysLog("failed to update user name cache: " + err.Error())
} }
}) })
} }

View File

@@ -37,7 +37,7 @@ func (user *UserBase) GetSetting() dto.UserSetting {
if user.Setting != "" { if user.Setting != "" {
err := common.Unmarshal([]byte(user.Setting), &setting) err := common.Unmarshal([]byte(user.Setting), &setting)
if err != nil { if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error()) common.SysLog("failed to unmarshal setting: " + err.Error())
} }
} }
return setting return setting
@@ -78,7 +78,7 @@ func GetUserCache(userId int) (userCache *UserBase, err error) {
if shouldUpdateRedis(fromDB, err) && user != nil { if shouldUpdateRedis(fromDB, err) && user != nil {
gopool.Go(func() { gopool.Go(func() {
if err := updateUserCache(*user); err != nil { if err := updateUserCache(*user); err != nil {
common.SysError("failed to update user status cache: " + err.Error()) common.SysLog("failed to update user status cache: " + err.Error())
} }
}) })
} }

View File

@@ -77,12 +77,12 @@ func batchUpdate() {
case BatchUpdateTypeUserQuota: case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value) err := increaseUserQuota(key, value)
if err != nil { if err != nil {
common.SysError("failed to batch update user quota: " + err.Error()) common.SysLog("failed to batch update user quota: " + err.Error())
} }
case BatchUpdateTypeTokenQuota: case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value) err := increaseTokenQuota(key, value)
if err != nil { if err != nil {
common.SysError("failed to batch update token quota: " + err.Error()) common.SysLog("failed to batch update token quota: " + err.Error())
} }
case BatchUpdateTypeUsedQuota: case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value) updateUserUsedQuota(key, value)

View File

@@ -4,107 +4,40 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/setting"
"one-api/types" "one-api/types"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
audioRequest := &dto.AudioRequest{} info.InitChannelMeta(c)
err := common.UnmarshalBodyReusable(c, audioRequest)
if err != nil {
return nil, err
}
switch info.RelayMode {
case relayconstant.RelayModeAudioSpeech:
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
default:
err = c.Request.ParseForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
if audioRequest.Model == "" {
audioRequest.Model = formData.Get("model")
}
if audioRequest.Model == "" { audioRequest, ok := info.Request.(*dto.AudioRequest)
return nil, errors.New("model is required") if !ok {
} return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
audioRequest.ResponseFormat = formData.Get("response_format")
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
}
return audioRequest, nil
}
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }
promptTokens := 0 err := helper.ModelMappedHelper(c, info, audioRequest)
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
} }
adaptor := GetAdaptor(relayInfo.ApiType) adaptor := GetAdaptor(info.ApiType)
if adaptor == nil { if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
} }
adaptor.Init(relayInfo) adaptor.Init(info)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) ioReader, err := adaptor.ConvertAudioRequest(c, info, *audioRequest)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
} }
resp, err := adaptor.DoRequest(c, relayInfo, ioReader) resp, err := adaptor.DoRequest(c, info, ioReader)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeDoRequestFailed) return types.NewError(err, types.ErrorCodeDoRequestFailed)
} }
@@ -121,14 +54,14 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
} }
} }
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo) usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
if newAPIError != nil { if newAPIError != nil {
// reset status code 重置状态码 // reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError
} }
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "") postConsumeQuota(c, info, usage.(*dto.Usage), "")
return nil return nil
} }

View File

@@ -34,20 +34,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string var fullRequestURL string
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
default: default:
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl)
case constant.RelayModeRerank: case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
case constant.RelayModeCompletions: case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
default: default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
} }
} }
@@ -118,7 +118,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
if info.IsStream { if info.IsStream {
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else { } else {

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/types" "one-api/types"
@@ -22,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
imageRequest.Input.Prompt = request.Prompt imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N imageRequest.Parameters.N = int(request.N)
imageRequest.ResponseFormat = request.ResponseFormat imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest return &imageRequest
} }
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) { func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID) url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
var aliResponse AliResponse var aliResponse AliResponse
@@ -43,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
common.SysError("updateTask client.Do err: " + err.Error()) common.SysLog("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil return &aliResponse, err, nil
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -53,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
var response AliResponse var response AliResponse
err = json.Unmarshal(responseBody, &response) err = json.Unmarshal(responseBody, &response)
if err != nil { if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error()) common.SysLog("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil return &aliResponse, err, nil
} }
@@ -109,7 +110,7 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
if responseFormat == "b64_json" { if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url) _, b64, err := service.GetImageFromUrl(data.Url)
if err != nil { if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error()) logger.LogError(c, "get_image_data_failed: "+err.Error())
continue continue
} }
b64Json = b64 b64Json = b64
@@ -134,14 +135,14 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
if err != nil { if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse) err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil { if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
} }
if aliTaskResponse.Message != "" { if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) logger.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
} }

View File

@@ -4,9 +4,9 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types" "one-api/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -36,7 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if err != nil { if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse) err = json.Unmarshal(responseBody, &aliResponse)

View File

@@ -8,6 +8,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service"
"strings" "strings"
"one-api/types" "one-api/types"
@@ -46,7 +47,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
model := c.GetString("model") model := c.GetString("model")
if model == "" { if model == "" {
@@ -148,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
var aliResponse AliResponse var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse) err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return true return true
} }
if aliResponse.Usage.OutputTokens != 0 { if aliResponse.Usage.OutputTokens != 0 {
@@ -161,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
lastResponseText = aliResponse.Output.Text lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysLog("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -171,7 +172,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
return false return false
} }
}) })
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return nil, &usage return nil, &usage
} }
@@ -181,7 +182,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
if err != nil { if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse) err = json.Unmarshal(responseBody, &aliResponse)
if err != nil { if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil

View File

@@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
common2 "one-api/common" common2 "one-api/common"
"one-api/logger"
"one-api/relay/common" "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
@@ -181,7 +182,7 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
err := helper.PingData(c) err := helper.PingData(c)
if err != nil { if err != nil {
common2.LogError(c, "SSE ping error: "+err.Error()) logger.LogError(c, "SSE ping error: "+err.Error())
done <- err done <- err
return return
} }

View File

@@ -101,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
default: default:
suffix += strings.ToLower(info.UpstreamModelName) suffix += strings.ToLower(info.UpstreamModelName)
} }
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix) fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix)
var accessToken string var accessToken string
var err error var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {

View File

@@ -118,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var baiduResponse BaiduChatStreamResponse var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse) err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return true return true
} }
if baiduResponse.Usage.TotalTokens != 0 { if baiduResponse.Usage.TotalTokens != 0 {
@@ -129,11 +129,11 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseBaidu2OpenAI(&baiduResponse) response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("error sending stream response: " + err.Error()) common.SysLog("error sending stream response: " + err.Error())
} }
return true return true
}) })
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return nil, usage return nil, usage
} }
@@ -143,7 +143,7 @@ func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
@@ -168,7 +168,7 @@ func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil

View File

@@ -45,15 +45,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeChatCompletions: case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil
case constant.RelayModeImagesEdits: case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil
case constant.RelayModeRerank: case constant.RelayModeRerank:
return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil
default: default:
} }
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

View File

@@ -53,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage { if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil
} else { } else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/relay/channel/openrouter" "one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
@@ -375,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
for _, toolCall := range message.ParseToolCalls() { for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any) inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue continue
} }
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{ claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
@@ -609,13 +610,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
var claudeResponse dto.ClaudeResponse var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse) err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody) return types.NewError(err, types.ErrorCodeBadResponseBody)
} }
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" { if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeError, http.StatusInternalServerError) return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
} }
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == types.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo) FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
if requestMode == RequestModeCompletion { if requestMode == RequestModeCompletion {
@@ -628,7 +629,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
} }
} }
helper.ClaudeChunkData(c, claudeResponse, data) helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI { } else if info.RelayFormat == types.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) { if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
@@ -637,7 +638,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
err = helper.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error()) logger.LogError(c, "send_stream_response_failed: "+err.Error())
} }
} }
return nil return nil
@@ -653,21 +654,20 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
} }
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done { if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled { if common.DebugEnabled {
common.SysError("claude response usage is not complete, maybe upstream error") common.SysLog("claude response usage is not complete, maybe upstream error")
} }
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens) claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
} }
} }
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == types.RelayFormatClaude {
// //
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI { } else if info.RelayFormat == types.RelayFormatOpenAI {
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage) response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysLog("send final response failed: " + err.Error())
} }
} }
helper.Done(c) helper.Done(c)
@@ -721,14 +721,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
} }
var responseData []byte var responseData []byte
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case types.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse) responseData, err = json.Marshal(openaiResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody) return types.NewError(err, types.ErrorCodeBadResponseBody)
} }
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
responseData = data responseData = data
} }
@@ -736,12 +736,12 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests) c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
} }
common.IOCopyBytesGracefully(c, nil, responseData) service.IOCopyBytesGracefully(c, nil, responseData)
return nil return nil
} }
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{ claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c), ResponseId: helper.GetResponseID(c),

View File

@@ -36,13 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeChatCompletions: case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil
case constant.RelayModeResponses: case constant.RelayModeResponses:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil
default: default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil
} }
} }

View File

@@ -5,8 +5,8 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
@@ -51,7 +51,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response) err := json.Unmarshal([]byte(data), &response)
if err != nil { if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) logger.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue continue
} }
for _, choice := range response.Choices { for _, choice := range response.Choices {
@@ -66,24 +66,24 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
info.FirstResponseTime = time.Now() info.FirstResponseTime = time.Now()
} }
if err != nil { if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error()) logger.LogError(c, "error_rendering_stream_response: "+err.Error())
} }
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error()) logger.LogError(c, "error_scanning_stream_response: "+err.Error())
} }
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response) err := helper.ObjectData(c, response)
if err != nil { if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) logger.LogError(c, "error_rendering_final_usage_response: "+err.Error())
} }
} }
helper.Done(c) helper.Done(c)
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return nil, usage return nil, usage
} }
@@ -93,7 +93,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var response dto.TextResponse var response dto.TextResponse
err = json.Unmarshal(responseBody, &response) err = json.Unmarshal(responseBody, &response)
if err != nil { if err != nil {
@@ -123,7 +123,7 @@ func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp) err = json.Unmarshal(responseBody, &cfResp)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil

View File

@@ -43,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else { } else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil
} }
} }

View File

@@ -118,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var cohereResp CohereResponse var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp) err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return true return true
} }
var openaiResp dto.ChatCompletionsStreamResponse var openaiResp dto.ChatCompletionsStreamResponse
@@ -153,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
} }
jsonStr, err := json.Marshal(openaiResp) jsonStr, err := json.Marshal(openaiResp)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysLog("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -175,7 +175,7 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp) err = json.Unmarshal(responseBody, &cohereResp)
if err != nil { if err != nil {
@@ -216,7 +216,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp) err = json.Unmarshal(responseBody, &cohereResp)
if err != nil { if err != nil {

View File

@@ -122,7 +122,7 @@ func (a *Adaptor) GetModelList() []string {
// GetRequestURL implements channel.Adaptor. // GetRequestURL implements channel.Adaptor.
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil
} }
// Init implements channel.Adaptor. // Init implements channel.Adaptor.

View File

@@ -49,7 +49,7 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
// convert coze response to openai response // convert coze response to openai response
var response dto.TextResponse var response dto.TextResponse
var cozeResponse CozeChatDetailResponse var cozeResponse CozeChatDetailResponse
@@ -154,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var chatData CozeChatResponseData var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData) err := json.Unmarshal([]byte(data), &chatData)
if err != nil { if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error()) common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return return
} }
@@ -171,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var messageData CozeChatV3MessageDetail var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData) err := json.Unmarshal([]byte(data), &messageData)
if err != nil { if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error()) common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return return
} }
var content string var content string
err = json.Unmarshal(messageData.Content, &content) err = json.Unmarshal(messageData.Content, &content)
if err != nil { if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error()) common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return return
} }
@@ -203,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var errorData CozeError var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData) err := json.Unmarshal([]byte(data), &errorData)
if err != nil { if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error()) common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return return
} }
common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message)) common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
} }
} }
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) { func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl) requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
// 将 conversationId和chatId作为参数发送get请求 // 将 conversationId和chatId作为参数发送get请求
@@ -258,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo
} }
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) { func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl) requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id") requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
req, err := http.NewRequest("GET", requestURL, nil) req, err := http.NewRequest("GET", requestURL, nil)

View File

@@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.BaseUrl fimBaseUrl := info.ChannelBaseUrl
if !strings.HasSuffix(info.BaseUrl, "/beta") { if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
fimBaseUrl += "/beta" fimBaseUrl += "/beta"
} }
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeCompletions: case constant.RelayModeCompletions:
return fmt.Sprintf("%s/completions", fimBaseUrl), nil return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default: default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} }
} }

View File

@@ -61,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch a.BotType { switch a.BotType {
case BotTypeWorkFlow: case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil
case BotTypeCompletion: case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil
case BotTypeAgent: case BotTypeAgent:
fallthrough fallthrough
default: default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil
} }
} }

View File

@@ -22,7 +22,7 @@ import (
) )
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile { func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl) uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl)
switch media.Type { switch media.Type {
case dto.ContentTypeImageURL: case dto.ContentTypeImageURL:
// Decode base64 data // Decode base64 data
@@ -36,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Decode base64 string // Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data) decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil { if err != nil {
common.SysError("failed to decode base64: " + err.Error()) common.SysLog("failed to decode base64: " + err.Error())
return nil return nil
} }
// Create temporary file // Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*") tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil { if err != nil {
common.SysError("failed to create temp file: " + err.Error()) common.SysLog("failed to create temp file: " + err.Error())
return nil return nil
} }
defer tempFile.Close() defer tempFile.Close()
@@ -51,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Write decoded data to temp file // Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil { if _, err := tempFile.Write(decodedData); err != nil {
common.SysError("failed to write to temp file: " + err.Error()) common.SysLog("failed to write to temp file: " + err.Error())
return nil return nil
} }
@@ -61,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Add user field // Add user field
if err := writer.WriteField("user", user); err != nil { if err := writer.WriteField("user", user); err != nil {
common.SysError("failed to add user field: " + err.Error()) common.SysLog("failed to add user field: " + err.Error())
return nil return nil
} }
@@ -74,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create form file // Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil { if err != nil {
common.SysError("failed to create form file: " + err.Error()) common.SysLog("failed to create form file: " + err.Error())
return nil return nil
} }
// Copy file content to form // Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
common.SysError("failed to copy file content: " + err.Error()) common.SysLog("failed to copy file content: " + err.Error())
return nil return nil
} }
writer.Close() writer.Close()
@@ -88,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create HTTP request // Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body) req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil { if err != nil {
common.SysError("failed to create request: " + err.Error()) common.SysLog("failed to create request: " + err.Error())
return nil return nil
} }
@@ -99,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
client := service.GetHttpClient() client := service.GetHttpClient()
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
common.SysError("failed to send request: " + err.Error()) common.SysLog("failed to send request: " + err.Error())
return nil return nil
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -109,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
Id string `json:"id"` Id string `json:"id"`
} }
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
common.SysError("failed to decode response: " + err.Error()) common.SysLog("failed to decode response: " + err.Error())
return nil return nil
} }
@@ -219,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var difyResponse DifyChunkChatCompletionResponse var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse) err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return true return true
} }
var openaiResponse dto.ChatCompletionsStreamResponse var openaiResponse dto.ChatCompletionsStreamResponse
@@ -239,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
} }
err = helper.ObjectData(c, openaiResponse) err = helper.ObjectData(c, openaiResponse)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysLog(err.Error())
} }
return true return true
}) })
@@ -258,7 +258,7 @@ func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse) err = json.Unmarshal(responseBody, &difyResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)

View File

@@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}, },
}, },
Parameters: dto.GeminiImageParameters{ Parameters: dto.GeminiImageParameters{
SampleCount: request.N, SampleCount: int(request.N),
AspectRatio: aspectRatio, AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult PersonGeneration: "allow_adult", // default allow adult
}, },
@@ -108,7 +108,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName) version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") { if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil
} }
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") || if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
@@ -118,7 +118,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.IsGeminiBatchEmbedding { if info.IsGeminiBatchEmbedding {
action = "batchEmbedContents" action = "batchEmbedContents"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
} }
action := "generateContent" action := "generateContent"
@@ -128,7 +128,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
info.DisablePing = true info.DisablePing = true
} }
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
@@ -17,7 +18,7 @@ import (
) )
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
// 读取响应体 // 读取响应体
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
@@ -53,13 +54,13 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
} }
} }
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil return &usage, nil
} }
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) { func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@@ -89,7 +90,7 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
} }
} }
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
return usage, nil return usage, nil
} }
@@ -106,7 +107,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
var geminiResponse dto.GeminiChatResponse var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse) err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false
} }
@@ -140,7 +141,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
// 直接发送 GeminiChatResponse 响应 // 直接发送 GeminiChatResponse 响应
err = helper.StringData(c, data) err = helper.StringData(c, data)
if err != nil { if err != nil {
common.LogError(c, err.Error()) logger.LogError(c, err.Error())
} }
info.SendResponseCount++ info.SendResponseCount++
return true return true

View File

@@ -9,6 +9,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
@@ -901,7 +902,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
var geminiResponse dto.GeminiChatResponse var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse) err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error()) logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false return false
} }
@@ -945,7 +946,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
finishReason = constant.FinishReasonToolCalls finishReason = constant.FinishReasonToolCalls
err = handleStream(c, info, emptyResponse) err = handleStream(c, info, emptyResponse)
if err != nil { if err != nil {
common.LogError(c, err.Error()) logger.LogError(c, err.Error())
} }
response.ClearToolCalls() response.ClearToolCalls()
@@ -957,7 +958,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
err = handleStream(c, info, response) err = handleStream(c, info, response)
if err != nil { if err != nil {
common.LogError(c, err.Error()) logger.LogError(c, err.Error())
} }
if isStop { if isStop {
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason)) _ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
@@ -993,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response) err := handleFinalStream(c, info, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysLog("send final response failed: " + err.Error())
} }
//if info.RelayFormat == relaycommon.RelayFormatOpenAI { //if info.RelayFormat == relaycommon.RelayFormatOpenAI {
// helper.Done(c) // helper.Done(c)
@@ -1007,7 +1008,7 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
if common.DebugEnabled { if common.DebugEnabled {
println(string(responseBody)) println(string(responseBody))
} }
@@ -1041,29 +1042,29 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
fullTextResponse.Usage = usage fullTextResponse.Usage = usage
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case types.RelayFormatOpenAI:
responseBody, err = common.Marshal(fullTextResponse) responseBody, err = common.Marshal(fullTextResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info) claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
claudeRespStr, err := common.Marshal(claudeResp) claudeRespStr, err := common.Marshal(claudeResp)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
responseBody = claudeRespStr responseBody = claudeRespStr
case relaycommon.RelayFormatGemini: case types.RelayFormatGemini:
break break
} }
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
return &usage, nil return &usage, nil
} }
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body) responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil { if readErr != nil {
@@ -1107,7 +1108,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
common.IOCopyBytesGracefully(c, resp, jsonResponse) service.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil return usage, nil
} }

View File

@@ -32,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types" "one-api/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -54,7 +54,7 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse) err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil { if err != nil {

View File

@@ -12,7 +12,7 @@ import (
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common" "one-api/logger"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -44,7 +44,7 @@ func SetPayloadHash(c *gin.Context, req any) error {
if err != nil { if err != nil {
return err return err
} }
common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body)) logger.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
payloadHash := sha256.Sum256(body) payloadHash := sha256.Sum256(body)
hexPayloadHash := hex.EncodeToString(payloadHash[:]) hexPayloadHash := hex.EncodeToString(payloadHash[:])
c.Set(HexPayloadHashKey, hexPayloadHash) c.Set(HexPayloadHashKey, hexPayloadHash)

View File

@@ -45,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} }
return "", errors.New("invalid relay mode") return "", errors.New("invalid relay mode")
} }

View File

@@ -6,5 +6,5 @@ import (
) )
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
} }

View File

@@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -54,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "m3e") { if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings" suffix = "embeddings"
} }
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix) fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix)
return fullRequestURL, nil return fullRequestURL, nil
} }

View File

@@ -7,6 +7,7 @@ import (
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types" "one-api/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -56,7 +57,7 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse) err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
@@ -77,6 +78,6 @@ func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse) service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil return &fullTextResponse.Usage, nil
} }

View File

@@ -44,19 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
default: default:
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions { } else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions { } else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
} }
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} }
} }
@@ -89,10 +89,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case types.RelayFormatOpenAI:
adaptor := openai.Adaptor{} adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info) return adaptor.DoResponse(c, resp, info)
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
if info.IsStream { if info.IsStream {
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else { } else {

View File

@@ -48,14 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == types.RelayFormatClaude {
return info.BaseUrl + "/v1/chat/completions", nil return info.ChannelBaseUrl + "/v1/chat/completions", nil
} }
switch info.RelayMode { switch info.RelayMode {
case relayconstant.RelayModeEmbeddings: case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embed", nil return info.ChannelBaseUrl + "/api/embed", nil
default: default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
} }

View File

@@ -94,7 +94,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse) err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -123,7 +123,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
common.IOCopyBytesGracefully(c, resp, doResponseBody) service.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil return usage, nil
} }

View File

@@ -105,14 +105,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == relayconstant.RelayModeRealtime { if info.RelayMode == relayconstant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") { if strings.HasPrefix(info.ChannelBaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://")
baseUrl = "wss://" + baseUrl baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl info.ChannelBaseUrl = baseUrl
} else if strings.HasPrefix(info.BaseUrl, "http://") { } else if strings.HasPrefix(info.ChannelBaseUrl, "http://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://") baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://")
baseUrl = "ws://" + baseUrl baseUrl = "ws://" + baseUrl
info.BaseUrl = baseUrl info.ChannelBaseUrl = baseUrl
} }
} }
switch info.ChannelType { switch info.ChannelType {
@@ -126,7 +126,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/") task := strings.TrimPrefix(requestURL, "/v1/")
if info.RelayFormat == relaycommon.RelayFormatClaude { if info.RelayFormat == types.RelayFormatClaude {
task = strings.TrimPrefix(task, "messages") task = strings.TrimPrefix(task, "messages")
task = "chat/completions" + task task = "chat/completions" + task
} }
@@ -136,7 +136,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
responsesApiVersion := "preview" responsesApiVersion := "preview"
subUrl := "/openai/v1/responses" subUrl := "/openai/v1/responses"
if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") { if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") {
subUrl = "/openai/responses" subUrl = "/openai/responses"
responsesApiVersion = apiVersion responsesApiVersion = apiVersion
} }
@@ -146,7 +146,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} }
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion) requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
} }
model_ := info.UpstreamModelName model_ := info.UpstreamModelName
@@ -159,18 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == relayconstant.RelayModeRealtime { if info.RelayMode == relayconstant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion) requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
} }
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
case constant.ChannelTypeMiniMax: case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info) return minimax.GetRequestURL(info)
case constant.ChannelTypeCustom: case constant.ChannelTypeCustom:
url := info.BaseUrl url := info.ChannelBaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1) url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil return url, nil
default: default:
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini { if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} }
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
} }

View File

@@ -7,10 +7,12 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"one-api/types"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -21,11 +23,11 @@ func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string
info.SendResponseCount++ info.SendResponseCount++
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case types.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent) return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
return handleClaudeFormat(c, data, info) return handleClaudeFormat(c, data, info)
case relaycommon.RelayFormatGemini: case types.RelayFormatGemini:
return handleGeminiFormat(c, data, info) return handleGeminiFormat(c, data, info)
} }
return nil return nil
@@ -50,7 +52,7 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error { func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil { if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
common.LogError(c, "failed to unmarshal stream response: "+err.Error()) logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
return err return err
} }
@@ -63,7 +65,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
geminiResponseStr, err := common.Marshal(geminiResponse) geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil { if err != nil {
common.LogError(c, "failed to marshal gemini response: "+err.Error()) logger.LogError(c, "failed to marshal gemini response: "+err.Error())
return err return err
} }
@@ -110,14 +112,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
var streamResponses []dto.ChatCompletionsStreamResponse var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析 // 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems { for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err return err
} }
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil { if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error()) common.SysLog("error processing stream response: " + err.Error())
} }
} }
return nil return nil
@@ -146,7 +148,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
var streamResponses []dto.CompletionsStreamResponse var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil { if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析 // 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems { for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil { if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
@@ -201,7 +203,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
usage *dto.Usage, containStreamUsage bool) { usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case types.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage { if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint) response.SetSystemFingerprint(systemFingerprint)
@@ -209,11 +211,11 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
} }
helper.Done(c) helper.Done(c)
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return return
} }
@@ -224,10 +226,10 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
_ = helper.ClaudeData(c, *resp) _ = helper.ClaudeData(c, *resp)
} }
case relaycommon.RelayFormatGemini: case types.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil { if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return return
} }
@@ -245,7 +247,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
geminiResponseStr, err := common.Marshal(geminiResponse) geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil { if err != nil {
common.SysError("error marshalling gemini response: " + err.Error()) common.SysLog("error marshalling gemini response: " + err.Error())
return return
} }

View File

@@ -10,6 +10,7 @@ import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
@@ -108,11 +109,11 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil { if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body") logger.LogError(c, "invalid response or response body")
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError) return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
} }
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName model := info.UpstreamModelName
var responseId string var responseId string
@@ -129,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
if lastStreamData != "" { if lastStreamData != "" {
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil { if err != nil {
common.SysError("error handling stream format: " + err.Error()) common.SysLog("error handling stream format: " + err.Error())
} }
} }
if len(data) > 0 { if len(data) > 0 {
@@ -143,10 +144,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
shouldSendLastResp := true shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage, if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil { &containStreamUsage, info, &shouldSendLastResp); err != nil {
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData)) logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
} }
if info.RelayFormat == relaycommon.RelayFormatOpenAI { if info.RelayFormat == types.RelayFormatOpenAI {
if shouldSendLastResp { if shouldSendLastResp {
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent) _ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
} }
@@ -154,7 +155,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
// 处理token计算 // 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil { if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.LogError(c, "error processing tokens: "+err.Error()) logger.LogError(c, "error processing tokens: "+err.Error())
} }
if !containStreamUsage { if !containStreamUsage {
@@ -173,7 +174,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
} }
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
@@ -210,7 +211,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} }
switch info.RelayFormat { switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI: case types.RelayFormatOpenAI:
if forceFormat { if forceFormat {
responseBody, err = common.Marshal(simpleResponse) responseBody, err = common.Marshal(simpleResponse)
if err != nil { if err != nil {
@@ -219,14 +220,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} else { } else {
break break
} }
case relaycommon.RelayFormatClaude: case types.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info) claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := common.Marshal(claudeResp) claudeRespStr, err := common.Marshal(claudeResp)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
responseBody = claudeRespStr responseBody = claudeRespStr
case relaycommon.RelayFormatGemini: case types.RelayFormatGemini:
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info) geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
geminiRespStr, err := common.Marshal(geminiResp) geminiRespStr, err := common.Marshal(geminiResp)
if err != nil { if err != nil {
@@ -235,7 +236,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
responseBody = geminiRespStr responseBody = geminiRespStr
} }
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
return &simpleResponse.Usage, nil return &simpleResponse.Usage, nil
} }
@@ -247,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
// if the upstream returns a specific status code, once the upstream has already written the header, // if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error, // the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly. // and can be terminated directly.
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens
@@ -258,13 +259,13 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
c.Writer.WriteHeaderNow() c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body) _, err := io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
common.LogError(c, err.Error()) logger.LogError(c, err.Error())
} }
return usage return usage
} }
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration // count tokens by audio file duration
audioTokens, err := countAudioTokens(c) audioTokens, err := countAudioTokens(c)
@@ -276,7 +277,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
// 写入新的 response body // 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{} usage := &dto.Usage{}
usage.PromptTokens = audioTokens usage.PromptTokens = audioTokens
@@ -386,7 +387,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err) errChan <- fmt.Errorf("error counting text token: %v", err)
return return
} }
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.TextTokens += textToken
@@ -459,7 +460,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err) errChan <- fmt.Errorf("error counting text token: %v", err)
return return
} }
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken localUsage.InputTokens += textToken + audioToken
@@ -474,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
localUsage = &dto.RealtimeUsage{} localUsage = &dto.RealtimeUsage{}
// print now usage // print now usage
} }
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session realtimeSession := realtimeEvent.Session
@@ -491,7 +492,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
errChan <- fmt.Errorf("error counting text token: %v", err) errChan <- fmt.Errorf("error counting text token: %v", err)
return return
} }
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken localUsage.OutputTokenDetails.TextTokens += textToken
@@ -517,7 +518,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.
case <-targetClosed: case <-targetClosed:
case err := <-errChan: case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error()) logger.LogError(c, "realtime error: "+err.Error())
case <-c.Done(): case <-c.Done():
} }
@@ -553,7 +554,7 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
} }
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@@ -567,7 +568,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
} }
// 写入新的 response body // 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore // Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content // because the upstream has already consumed resources and returned content

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
@@ -16,7 +17,7 @@ import (
) )
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
// read response body // read response body
var responsesResponse dto.OpenAIResponsesResponse var responsesResponse dto.OpenAIResponsesResponse
@@ -33,7 +34,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
} }
// 写入新的 response body // 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody) service.IOCopyBytesGracefully(c, resp, responseBody)
// compute usage // compute usage
usage := dto.Usage{} usage := dto.Usage{}
@@ -54,7 +55,7 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil { if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body") logger.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
} }

View File

@@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -58,15 +58,15 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
go func() { go func() {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
common.SysError("error reading stream response: " + err.Error()) common.SysLog("error reading stream response: " + err.Error())
stopChan <- true stopChan <- true
return return
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse) err = json.Unmarshal(responseBody, &palmResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
stopChan <- true stopChan <- true
return return
} }
@@ -78,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
} }
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysLog("error marshalling stream response: " + err.Error())
stopChan <- true stopChan <- true
return return
} }
@@ -96,7 +96,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
return false return false
} }
}) })
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return nil, responseText return nil, responseText
} }
@@ -105,7 +105,7 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse) err = json.Unmarshal(responseBody, &palmResponse)
if err != nil { if err != nil {
@@ -133,6 +133,6 @@ func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respons
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse) service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil return &usage, nil
} }

View File

@@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank { if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings { } else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions { } else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions { } else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
} }
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -4,9 +4,9 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types" "one-api/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -17,7 +17,7 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp) err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil { if err != nil {
@@ -39,6 +39,6 @@ func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse) service.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil return usage, nil
} }

View File

@@ -76,7 +76,7 @@ type TaskAdaptor struct {
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl a.baseURL = info.ChannelBaseUrl
// apiKey format: "access_key|secret_key" // apiKey format: "access_key|secret_key"
keyParts := strings.Split(info.ApiKey, "|") keyParts := strings.Split(info.ApiKey, "|")

View File

@@ -81,7 +81,7 @@ type TaskAdaptor struct {
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl a.baseURL = info.ChannelBaseUrl
a.apiKey = info.ApiKey a.apiKey = info.ApiKey
// apiKey format: "access_key|secret_key" // apiKey format: "access_key|secret_key"

View File

@@ -59,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
} }
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
baseURL := info.BaseUrl baseURL := info.ChannelBaseUrl
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
return fullRequestURL, nil return fullRequestURL, nil
} }
@@ -139,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody)) req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil { if err != nil {
common.SysError(fmt.Sprintf("Get Task error: %v", err)) common.SysLog(fmt.Sprintf("Get Task error: %v", err))
return nil, err return nil, err
} }
defer req.Body.Close() defer req.Body.Close()

View File

@@ -86,7 +86,7 @@ type TaskAdaptor struct {
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl a.baseURL = info.ChannelBaseUrl
} }
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {

View File

@@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -106,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
var tencentResponse TencentChatResponse var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse) err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
continue continue
} }
@@ -117,17 +117,17 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
err = helper.ObjectData(c, response) err = helper.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysLog(err.Error())
} }
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error()) common.SysLog("error reading stream: " + err.Error())
} }
helper.Done(c) helper.Done(c)
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
} }
@@ -138,7 +138,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb) err = json.Unmarshal(responseBody, &tencentSb)
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
@@ -156,7 +156,7 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse) service.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil return &fullTextResponse.Usage, nil
} }

View File

@@ -188,17 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeChatCompletions: case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") { if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
} }
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
case constant.RelayModeEmbeddings: case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
case constant.RelayModeImagesEdits: case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
case constant.RelayModeRerank: case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
default: default:
} }
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

View File

@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
xaiRequest := ImageRequest{ xaiRequest := ImageRequest{
Model: request.Model, Model: request.Model,
Prompt: request.Prompt, Prompt: request.Prompt,
N: request.N, N: int(request.N),
ResponseFormat: request.ResponseFormat, ResponseFormat: request.ResponseFormat,
} }
return xaiRequest, nil return xaiRequest, nil
@@ -49,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -47,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var xAIResp *dto.ChatCompletionsStreamResponse var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp) err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return true return true
} }
@@ -63,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount) _ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse) err = helper.ObjectData(c, openaiResponse)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysLog(err.Error())
} }
return true return true
}) })
@@ -74,12 +74,12 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
} }
helper.Done(c) helper.Done(c)
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return usage, nil return usage, nil
} }
func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp) defer service.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@@ -101,7 +101,7 @@ func xAIHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
} }
common.IOCopyBytesGracefully(c, resp, encodeJson) service.IOCopyBytesGracefully(c, resp, encodeJson)
return xaiResponse.Usage, nil return xaiResponse.Usage, nil
} }

View File

@@ -143,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
response := streamResponseXunfei2OpenAI(&xunfeiResponse) response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysLog("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -206,6 +206,11 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
if err != nil || resp.StatusCode != 101 { if err != nil || resp.StatusCode != 101 {
return nil, nil, err return nil, nil, err
} }
defer func() {
conn.Close()
}()
data := requestOpenAI2Xunfei(textRequest, appId, domain) data := requestOpenAI2Xunfei(textRequest, appId, domain)
err = conn.WriteJSON(data) err = conn.WriteJSON(data)
if err != nil { if err != nil {
@@ -218,20 +223,19 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
for { for {
_, msg, err := conn.ReadMessage() _, msg, err := conn.ReadMessage()
if err != nil { if err != nil {
common.SysError("error reading stream response: " + err.Error()) common.SysLog("error reading stream response: " + err.Error())
break break
} }
var response XunfeiChatResponse var response XunfeiChatResponse
err = json.Unmarshal(msg, &response) err = json.Unmarshal(msg, &response)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
break break
} }
dataChan <- response dataChan <- response
if response.Payload.Choices.Status == 2 { if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil { if err != nil {
common.SysError("error closing websocket connection: " + err.Error()) common.SysLog("error closing websocket connection: " + err.Error())
} }
break break
} }

View File

@@ -45,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.IsStream { if info.IsStream {
method = "sse-invoke" method = "sse-invoke"
} }
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service"
"one-api/types" "one-api/types"
"strings" "strings"
"sync" "sync"
@@ -38,7 +39,7 @@ func getZhipuToken(apikey string) string {
split := strings.Split(apikey, ".") split := strings.Split(apikey, ".")
if len(split) != 2 { if len(split) != 2 {
common.SysError("invalid zhipu key: " + apikey) common.SysLog("invalid zhipu key: " + apikey)
return "" return ""
} }
@@ -186,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseZhipu2OpenAI(data) response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysLog("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -195,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var zhipuResponse ZhipuStreamMetaResponse var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse) err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysLog("error unmarshalling stream response: " + err.Error())
return true return true
} }
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
common.SysError("error marshalling stream response: " + err.Error()) common.SysLog("error marshalling stream response: " + err.Error())
return true return true
} }
usage = zhipuUsage usage = zhipuUsage
@@ -212,7 +213,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
return false return false
} }
}) })
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
return usage, nil return usage, nil
} }
@@ -222,7 +223,7 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) service.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse) err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil { if err != nil {
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)

Some files were not shown because too many files have changed in this diff Show More