diff --git a/common/constants.go b/common/constants.go index ed50fae6..dd4f3b04 100644 --- a/common/constants.go +++ b/common/constants.go @@ -62,6 +62,10 @@ var EmailDomainWhitelist = []string{ "yahoo.com", "foxmail.com", } +var EmailLoginAuthServerList = []string{ + "smtp.sendcloud.net", + "smtp.azurecomm.net", +} var DebugEnabled bool var MemoryCacheEnabled bool diff --git a/common/email.go b/common/email.go index 8eb575f3..18e6dbf7 100644 --- a/common/email.go +++ b/common/email.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "fmt" "net/smtp" + "slices" "strings" "time" ) @@ -79,7 +80,7 @@ func SendEmail(subject string, receiver string, content string) error { if err != nil { return err } - } else if isOutlookServer(SMTPAccount) || SMTPServer == "smtp.azurecomm.net" { + } else if isOutlookServer(SMTPAccount) || slices.Contains(EmailLoginAuthServerList, SMTPServer) { auth = LoginAuth(SMTPAccount, SMTPToken) err = smtp.SendMail(addr, auth, SMTPFrom, to, mail) } else { diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go new file mode 100644 index 00000000..ef5d1935 --- /dev/null +++ b/common/limiter/limiter.go @@ -0,0 +1,89 @@ +package limiter + +import ( + "context" + _ "embed" + "fmt" + "github.com/go-redis/redis/v8" + "one-api/common" + "sync" +) + +//go:embed lua/rate_limit.lua +var rateLimitScript string + +type RedisLimiter struct { + client *redis.Client + limitScriptSHA string +} + +var ( + instance *RedisLimiter + once sync.Once +) + +func New(ctx context.Context, r *redis.Client) *RedisLimiter { + once.Do(func() { + // 预加载脚本 + limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() + if err != nil { + common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) + } + instance = &RedisLimiter{ + client: r, + limitScriptSHA: limitSHA, + } + }) + + return instance +} + +func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) { + // 默认配置 + config := &Config{ + Capacity: 10, + Rate: 1, + Requested: 1, + } + + // 应用选项模式 + for _, opt := range opts { + opt(config) + } + + // 执行限流 + result, err := rl.client.EvalSha( + ctx, + rl.limitScriptSHA, + []string{key}, + config.Requested, + config.Rate, + config.Capacity, + ).Int() + + if err != nil { + return false, fmt.Errorf("rate limit failed: %w", err) + } + return result == 1, nil +} + +// Config 配置选项模式 +type Config struct { + Capacity int64 + Rate int64 + Requested int64 +} + +type Option func(*Config) + +func WithCapacity(c int64) Option { + return func(cfg *Config) { cfg.Capacity = c } +} + +func WithRate(r int64) Option { + return func(cfg *Config) { cfg.Rate = r } +} + +func WithRequested(n int64) Option { + return func(cfg *Config) { cfg.Requested = n } +} diff --git a/common/limiter/lua/rate_limit.lua b/common/limiter/lua/rate_limit.lua new file mode 100644 index 00000000..c07fd3a8 --- /dev/null +++ b/common/limiter/lua/rate_limit.lua @@ -0,0 +1,44 @@ +-- 令牌桶限流器 +-- KEYS[1]: 限流器唯一标识 +-- ARGV[1]: 请求令牌数 (通常为1) +-- ARGV[2]: 令牌生成速率 (每秒) +-- ARGV[3]: 桶容量 + +local key = KEYS[1] +local requested = tonumber(ARGV[1]) +local rate = tonumber(ARGV[2]) +local capacity = tonumber(ARGV[3]) + +-- 获取当前时间(Redis服务器时间) +local now = redis.call('TIME') +local nowInSeconds = tonumber(now[1]) + +-- 获取桶状态 +local bucket = redis.call('HMGET', key, 'tokens', 'last_time') +local tokens = tonumber(bucket[1]) +local last_time = tonumber(bucket[2]) + +-- 初始化桶(首次请求或过期) +if not tokens or not last_time then + tokens = capacity + last_time = nowInSeconds +else + -- 计算新增令牌 + local elapsed = nowInSeconds - last_time + local add_tokens = elapsed * rate + tokens = math.min(capacity, tokens + add_tokens) + last_time = nowInSeconds +end + +-- 判断是否允许请求 +local allowed = false +if tokens >= requested then + tokens = tokens - requested + allowed = true +end + +---- 更新桶状态并设置过期时间 +redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time) +--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间 + +return allowed and 1 or 0 \ No newline at end of file diff --git a/common/utils.go b/common/utils.go index e57801e3..587de537 100644 --- a/common/utils.go +++ b/common/utils.go @@ -7,7 +7,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/pkg/errors" "html/template" "io" "log" @@ -22,6 +21,7 @@ import ( "time" "github.com/google/uuid" + "github.com/pkg/errors" ) func OpenBrowser(url string) { diff --git a/controller/channel-test.go b/controller/channel-test.go index 99ba04b0..d1cb4093 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -103,7 +103,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr } request := buildTestRequest(testModel) - common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info)) + // 创建一个用于日志的 info 副本,移除 ApiKey + logInfo := *info + logInfo.ApiKey = "" + common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) if err != nil { diff --git a/middleware/model-rate-limit.go b/middleware/model-rate-limit.go index bd5f9d25..581dc451 100644 --- a/middleware/model-rate-limit.go +++ b/middleware/model-rate-limit.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "one-api/common" + "one-api/common/limiter" "one-api/setting" "strconv" "time" @@ -78,21 +79,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g ctx := context.Background() rdb := common.RDB - // 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过) - totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId) - allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration) - if err != nil { - fmt.Println("检查总请求数限制失败:", err.Error()) - abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") - return - } - if !allowed { - abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) - } - - // 2. 检查成功请求数限制 + // 1. 检查成功请求数限制 successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) - allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) + allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) if err != nil { fmt.Println("检查成功请求数限制失败:", err.Error()) abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") @@ -103,8 +92,27 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g return } - // 3. 记录总请求(当totalMaxCount为0时会自动跳过) - recordRedisRequest(ctx, rdb, totalKey, totalMaxCount) + //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 + totalKey := fmt.Sprintf("rateLimit:%s", userId) + // 初始化 + tb := limiter.New(ctx, rdb) + allowed, err = tb.Allow( + ctx, + totalKey, + limiter.WithCapacity(int64(totalMaxCount)*duration), + limiter.WithRate(int64(totalMaxCount)), + limiter.WithRequested(duration), + ) + + if err != nil { + fmt.Println("检查总请求数限制失败:", err.Error()) + abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") + return + } + + if !allowed { + abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) + } // 4. 处理请求 c.Next() diff --git a/model/user.go b/model/user.go index c15e5370..0aea2ff5 100644 --- a/model/user.go +++ b/model/user.go @@ -108,7 +108,7 @@ func CheckUserExistOrDeleted(username string, email string) (bool, error) { func GetMaxUserId() int { var user User - DB.Last(&user) + DB.Unscoped().Last(&user) return user.Id } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0404ce85..95e7c4be 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -246,17 +246,23 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla } else { imageUrl := mediaMessage.GetImageMedia() claudeMediaMessage.Type = "image" - claudeMediaMessage.Source = &dto.ClaudeMessageSource{} + claudeMediaMessage.Source = &dto.ClaudeMessageSource{ + Type: "base64", + } // 判断是否是url if strings.HasPrefix(imageUrl.Url, "http") { - claudeMediaMessage.Source.Type = "url" - claudeMediaMessage.Source.Url = imageUrl.Url + // 是url,获取图片的类型和base64编码的数据 + fileData, err := service.GetFileBase64FromUrl(imageUrl.Url) + if err != nil { + return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) + } + claudeMediaMessage.Source.MediaType = fileData.MimeType + claudeMediaMessage.Source.Data = fileData.Base64Data } else { _, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url) if err != nil { return nil, err } - claudeMediaMessage.Source.Type = "base64" claudeMediaMessage.Source.MediaType = "image/" + format claudeMediaMessage.Source.Data = base64String } diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index 57accc8f..f6e910e8 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -11,6 +11,7 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" + "strings" ) type Adaptor struct { @@ -36,9 +37,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + fimBaseUrl := info.BaseUrl + if !strings.HasSuffix(info.BaseUrl, "/beta") { + fimBaseUrl += "/beta" + } switch info.RelayMode { case constant.RelayModeCompletions: - return fmt.Sprintf("%s/beta/completions", info.BaseUrl), nil + return fmt.Sprintf("%s/completions", fimBaseUrl), nil default: return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index 15763298..70219cc1 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -17,6 +16,8 @@ import ( "one-api/service" "one-api/setting" "strings" + + "github.com/gin-gonic/gin" ) func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { @@ -81,6 +82,50 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. imageRequest.Size = "1024x1024" } + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err + } + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") + } + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") + } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + // x.ai grok-2-image not support size, quality or style + if imageRequest.Size == "empty" { + imageRequest.Size = "" + } + + // Not "256x256", "512x512", or "1024x1024" + if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { + if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + } + } else if imageRequest.Model == "dall-e-3" { + if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") + } + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + //if imageRequest.N != 1 { + // return nil, errors.New("n must be 1") + //} + } + // N should between 1 and 10 + //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { + // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + //} if setting.ShouldCheckPromptSensitive() { words, err := service.CheckSensitiveInput(imageRequest.Prompt) if err != nil { diff --git a/web/src/components/ModelPricing.js b/web/src/components/ModelPricing.js index f8390c2c..16eb08f1 100644 --- a/web/src/components/ModelPricing.js +++ b/web/src/components/ModelPricing.js @@ -81,7 +81,7 @@ const ModelPricing = () => { } function renderAvailable(available) { - return ( + return available ? ( {t('您的分组可以使用该模型')} @@ -98,7 +98,7 @@ const ModelPricing = () => { > - ); + ) : null; } const columns = [ @@ -109,7 +109,12 @@ const ModelPricing = () => { // if record.enable_groups contains selectedGroup, then available is true return renderAvailable(record.enable_groups.includes(selectedGroup)); }, - sorter: (a, b) => a.available - b.available, + sorter: (a, b) => { + const aAvailable = a.enable_groups.includes(selectedGroup); + const bAvailable = b.enable_groups.includes(selectedGroup); + return Number(aAvailable) - Number(bAvailable); + }, + defaultSortOrder: 'descend', }, { title: t('模型名称'),