refactor: 调整代码,符合项目现有规范

This commit is contained in:
tbphp
2025-05-05 19:32:22 +08:00
parent 1e1d24d1b0
commit 1513ed7847
4 changed files with 65 additions and 66 deletions

View File

@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/limiter" "one-api/common/limiter"
"one-api/constant"
"one-api/setting" "one-api/setting"
"strconv" "strconv"
"time" "time"
@@ -19,20 +20,25 @@ const (
ModelRequestRateLimitSuccessCountMark = "MRRLS" ModelRequestRateLimitSuccessCountMark = "MRRLS"
) )
// 检查Redis中的请求限制
func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
// 如果maxCount为0表示不限制
if maxCount == 0 { if maxCount == 0 {
return true, nil return true, nil
} }
// 获取当前计数
length, err := rdb.LLen(ctx, key).Result() length, err := rdb.LLen(ctx, key).Result()
if err != nil { if err != nil {
return false, err return false, err
} }
// 如果未达到限制,允许请求
if length < int64(maxCount) { if length < int64(maxCount) {
return true, nil return true, nil
} }
// 检查时间窗口
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr) oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil { if err != nil {
@@ -44,6 +50,7 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
if err != nil { if err != nil {
return false, err return false, err
} }
// 如果在时间窗口内已达到限制,拒绝请求
subTime := nowTime.Sub(oldTime).Seconds() subTime := nowTime.Sub(oldTime).Seconds()
if int64(subTime) < duration { if int64(subTime) < duration {
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
@@ -53,7 +60,9 @@ func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, max
return true, nil return true, nil
} }
// 记录Redis请求
func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
// 如果maxCount为0不记录请求
if maxCount == 0 { if maxCount == 0 {
return return
} }
@@ -64,12 +73,14 @@ func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxC
rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
} }
// Redis限流处理器
func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
userId := strconv.Itoa(c.GetInt("id")) userId := strconv.Itoa(c.GetInt("id"))
ctx := context.Background() ctx := context.Background()
rdb := common.RDB rdb := common.RDB
// 1. 检查成功请求数限制
successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
if err != nil { if err != nil {
@@ -82,7 +93,9 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
return return
} }
//2.检查总请求数限制并记录总请求当totalMaxCount为0时会自动跳过使用令牌桶限流器
totalKey := fmt.Sprintf("rateLimit:%s", userId) totalKey := fmt.Sprintf("rateLimit:%s", userId)
// 初始化
tb := limiter.New(ctx, rdb) tb := limiter.New(ctx, rdb)
allowed, err = tb.Allow( allowed, err = tb.Allow(
ctx, ctx,
@@ -102,14 +115,17 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次包括失败次数请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
} }
// 4. 处理请求
c.Next() c.Next()
// 5. 如果请求成功,记录成功请求
if c.Writer.Status() < 400 { if c.Writer.Status() < 400 {
recordRedisRequest(ctx, rdb, successKey, successMaxCount) recordRedisRequest(ctx, rdb, successKey, successMaxCount)
} }
} }
} }
// 内存限流处理器
func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
@@ -118,12 +134,15 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
totalKey := ModelRequestRateLimitCountMark + userId totalKey := ModelRequestRateLimitCountMark + userId
successKey := ModelRequestRateLimitSuccessCountMark + userId successKey := ModelRequestRateLimitSuccessCountMark + userId
// 1. 检查总请求数限制当totalMaxCount为0时跳过
if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
c.Status(http.StatusTooManyRequests) c.Status(http.StatusTooManyRequests)
c.Abort() c.Abort()
return return
} }
// 2. 检查成功请求数限制
// 使用一个临时key来检查限制这样可以避免实际记录
checkKey := successKey + "_check" checkKey := successKey + "_check"
if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
c.Status(http.StatusTooManyRequests) c.Status(http.StatusTooManyRequests)
@@ -131,51 +150,48 @@ func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int)
return return
} }
// 3. 处理请求
c.Next() c.Next()
// 4. 如果请求成功,记录到实际的成功请求计数中
if c.Writer.Status() < 400 { if c.Writer.Status() < 400 {
inMemoryRateLimiter.Request(successKey, successMaxCount, duration) inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
} }
} }
} }
// ModelRequestRateLimit 模型请求限流中间件
func ModelRequestRateLimit() func(c *gin.Context) { func ModelRequestRateLimit() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
// 在每个请求时检查是否启用限流
if !setting.ModelRequestRateLimitEnabled { if !setting.ModelRequestRateLimitEnabled {
c.Next() c.Next()
return return
} }
// 计算限流参数
duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
totalMaxCount := setting.ModelRequestRateLimitCount
successMaxCount := setting.ModelRequestRateLimitSuccessCount
// 获取分组
group := c.GetString("token_group") group := c.GetString("token_group")
if group == "" { if group == "" {
group = c.GetString("group") group = c.GetString(constant.ContextKeyUserGroup)
}
if group == "" {
group = "default"
} }
finalTotalCount := setting.ModelRequestRateLimitCount //获取分组的限流配置
finalSuccessCount := setting.ModelRequestRateLimitSuccessCount
foundGroupLimit := false
groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
if found { if found {
finalTotalCount = groupTotalCount totalMaxCount = groupTotalCount
finalSuccessCount = groupSuccessCount successMaxCount = groupSuccessCount
foundGroupLimit = true
common.LogWarn(c.Request.Context(), fmt.Sprintf("Using rate limit for group '%s': total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
}
if !foundGroupLimit {
common.LogInfo(c.Request.Context(), fmt.Sprintf("No specific rate limit found for group '%s', using global limits: total=%d, success=%d", group, finalTotalCount, finalSuccessCount))
} }
// 根据存储类型选择并执行限流处理器
if common.RedisEnabled { if common.RedisEnabled {
redisRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
} else { } else {
memoryRateLimitHandler(duration, finalTotalCount, finalSuccessCount)(c) memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
} }
} }
} }

View File

@@ -1,8 +1,6 @@
package model package model
import ( import (
"encoding/json"
"fmt"
"one-api/common" "one-api/common"
"one-api/setting" "one-api/setting"
"one-api/setting/config" "one-api/setting/config"
@@ -94,8 +92,7 @@ func InitOptionMap() {
common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount) common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes) common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount) common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
jsonBytes, _ := json.Marshal(map[string][2]int{}) common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
common.OptionMap["ModelRequestRateLimitGroup"] = string(jsonBytes)
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString() common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString() common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -154,31 +151,18 @@ func SyncOptions(frequency int) {
} }
func UpdateOption(key string, value string) error { func UpdateOption(key string, value string) error {
originalValue := value // Save to database first
if key == "ModelRequestRateLimitGroup" {
var cfg map[string][2]int
err := json.Unmarshal([]byte(originalValue), &cfg)
if err != nil {
return fmt.Errorf("无效的 JSON 格式 for %s: %w", key, err)
}
formattedValueBytes, marshalErr := json.MarshalIndent(cfg, "", " ")
if marshalErr != nil {
return fmt.Errorf("failed to marshal validated %s config: %w", key, marshalErr)
}
value = string(formattedValueBytes)
}
option := Option{ option := Option{
Key: key, Key: key,
} }
// https://gorm.io/docs/update.html#Save-All-Fields
DB.FirstOrCreate(&option, Option{Key: key}) DB.FirstOrCreate(&option, Option{Key: key})
option.Value = value option.Value = value
if err := DB.Save(&option).Error; err != nil { // Save is a combination function.
return fmt.Errorf("保存选项 %s 到数据库失败: %w", key, err) // If save value does not contain primary key, it will execute Create,
} // otherwise it will execute Update (with all fields).
DB.Save(&option)
// Update OptionMap
return updateOptionMap(key, value) return updateOptionMap(key, value)
} }
@@ -356,7 +340,7 @@ func updateOptionMap(key string, value string) (err error) {
case "ModelRequestRateLimitSuccessCount": case "ModelRequestRateLimitSuccessCount":
setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value) setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
case "ModelRequestRateLimitGroup": case "ModelRequestRateLimitGroup":
err = setting.UpdateModelRequestRateLimitGroup(value) err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval": case "DataExportInterval":

View File

@@ -2,7 +2,6 @@ package setting
import ( import (
"encoding/json" "encoding/json"
"fmt"
"one-api/common" "one-api/common"
"sync" "sync"
) )
@@ -11,33 +10,31 @@ var ModelRequestRateLimitEnabled = false
var ModelRequestRateLimitDurationMinutes = 1 var ModelRequestRateLimitDurationMinutes = 1
var ModelRequestRateLimitCount = 0 var ModelRequestRateLimitCount = 0
var ModelRequestRateLimitSuccessCount = 1000 var ModelRequestRateLimitSuccessCount = 1000
var ModelRequestRateLimitGroup map[string][2]int var ModelRequestRateLimitGroup = map[string][2]int{}
var ModelRequestRateLimitMutex sync.RWMutex
var ModelRequestRateLimitGroupMutex sync.RWMutex func ModelRequestRateLimitGroup2JSONString() string {
ModelRequestRateLimitMutex.RLock()
defer ModelRequestRateLimitMutex.RUnlock()
func UpdateModelRequestRateLimitGroup(jsonStr string) error { jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
ModelRequestRateLimitGroupMutex.Lock()
defer ModelRequestRateLimitGroupMutex.Unlock()
var newConfig map[string][2]int
if jsonStr == "" || jsonStr == "{}" {
ModelRequestRateLimitGroup = make(map[string][2]int)
common.SysLog("Model request rate limit group config cleared")
return nil
}
err := json.Unmarshal([]byte(jsonStr), &newConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to unmarshal ModelRequestRateLimitGroup config: %w", err) common.SysError("error marshalling model ratio: " + err.Error())
} }
return string(jsonBytes)
}
ModelRequestRateLimitGroup = newConfig func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
return nil ModelRequestRateLimitMutex.RLock()
defer ModelRequestRateLimitMutex.RUnlock()
ModelRequestRateLimitGroup = make(map[string][2]int)
return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup)
} }
func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) { func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
ModelRequestRateLimitGroupMutex.RLock() ModelRequestRateLimitMutex.RLock()
defer ModelRequestRateLimitGroupMutex.RUnlock() defer ModelRequestRateLimitMutex.RUnlock()
if ModelRequestRateLimitGroup == nil { if ModelRequestRateLimitGroup == nil {
return 0, 0, false return 0, 0, false

View File

@@ -24,7 +24,6 @@ const RateLimitSetting = () => {
if (success) { if (success) {
let newInputs = {}; let newInputs = {};
data.forEach((item) => { data.forEach((item) => {
// 检查 key 是否在初始 inputs 中定义
if (Object.prototype.hasOwnProperty.call(inputs, item.key)) { if (Object.prototype.hasOwnProperty.call(inputs, item.key)) {
if (item.key.endsWith('Enabled')) { if (item.key.endsWith('Enabled')) {
newInputs[item.key] = item.value === 'true'; newInputs[item.key] = item.value === 'true';
@@ -33,6 +32,7 @@ const RateLimitSetting = () => {
} }
} }
}); });
setInputs(newInputs); setInputs(newInputs);
} else { } else {
showError(message); showError(message);
@@ -42,6 +42,7 @@ const RateLimitSetting = () => {
try { try {
setLoading(true); setLoading(true);
await getOptions(); await getOptions();
// showSuccess('刷新成功');
} catch (error) { } catch (error) {
showError('刷新失败'); showError('刷新失败');
} finally { } finally {
@@ -56,6 +57,7 @@ const RateLimitSetting = () => {
return ( return (
<> <>
<Spin spinning={loading} size='large'> <Spin spinning={loading} size='large'>
{/* AI请求速率限制 */}
<Card style={{ marginTop: '10px' }}> <Card style={{ marginTop: '10px' }}>
<RequestRateLimit options={inputs} refresh={onRefresh} /> <RequestRateLimit options={inputs} refresh={onRefresh} />
</Card> </Card>