diff --git a/README.en.md b/README.en.md
index feb4b0bb..446c88f6 100644
--- a/README.en.md
+++ b/README.en.md
@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
+- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
## Deployment
diff --git a/README.md b/README.md
index cecefca6..e678832d 100644
--- a/README.md
+++ b/README.md
@@ -95,6 +95,9 @@
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
+- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
+- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
+
## 部署
> [!TIP]
diff --git a/common/constants.go b/common/constants.go
index f967d066..04fb1b9a 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
var RetryTimes = 0
-var RootUserEmail = ""
+//var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
diff --git a/common/logger.go b/common/logger.go
index 93d557d8..e72a73af 100644
--- a/common/logger.go
+++ b/common/logger.go
@@ -100,6 +100,14 @@ func LogQuota(quota int) string {
}
}
+func FormatQuota(quota int) string {
+ if DisplayInCurrencyEnabled {
+ return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
+ } else {
+ return fmt.Sprintf("%d", quota)
+ }
+}
+
// LogJson 仅供测试使用 only for test
func LogJson(ctx context.Context, msg string, obj any) {
jsonStr, err := json.Marshal(obj)
diff --git a/common/model-ratio.go b/common/model-ratio.go
index bb94ad36..542cd93c 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -233,7 +233,11 @@ var (
modelRatioMapMutex = sync.RWMutex{}
)
-var CompletionRatio map[string]float64 = nil
+var (
+ CompletionRatio map[string]float64 = nil
+ CompletionRatioMutex = sync.RWMutex{}
+)
+
var defaultCompletionRatio = map[string]float64{
"gpt-4-gizmo-*": 2,
"gpt-4o-gizmo-*": 3,
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
return defaultModelRatio
}
-func CompletionRatio2JSONString() string {
+func GetCompletionRatioMap() map[string]float64 {
+ CompletionRatioMutex.Lock()
+ defer CompletionRatioMutex.Unlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
+ return CompletionRatio
+}
+
+func CompletionRatio2JSONString() string {
+ GetCompletionRatioMap()
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
SysError("error marshalling completion ratio: " + err.Error())
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
+ CompletionRatioMutex.Lock()
+ defer CompletionRatioMutex.Unlock()
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
+ GetCompletionRatioMap()
+
if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
}
return 2
}
-
-//func GetAudioPricePerMinute(name string) float64 {
-// if strings.HasPrefix(name, "gpt-4o-realtime") {
-// return 0.06
-// }
-// return 0.06
-//}
-//
-//func GetAudioCompletionPricePerMinute(name string) float64 {
-// if strings.HasPrefix(name, "gpt-4o-realtime") {
-// return 0.24
-// }
-// return 0.24
-//}
-
-func GetCompletionRatioMap() map[string]float64 {
- if CompletionRatio == nil {
- CompletionRatio = defaultCompletionRatio
- }
- return CompletionRatio
-}
diff --git a/constant/env.go b/constant/env.go
index 4135e8c7..bffbfeea 100644
--- a/constant/env.go
+++ b/constant/env.go
@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
+var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
+var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
+
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
@@ -44,5 +47,5 @@ func InitEnv() {
}
}
-// 是否生成初始令牌,默认关闭。
+// GenerateDefaultToken 是否生成初始令牌,默认关闭。
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
diff --git a/constant/user_setting.go b/constant/user_setting.go
new file mode 100644
index 00000000..a5b921b2
--- /dev/null
+++ b/constant/user_setting.go
@@ -0,0 +1,14 @@
+package constant
+
+var (
+ UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
+ UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
+ UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
+ UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
+ UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
+)
+
+var (
+ NotifyTypeEmail = "email" // Email 邮件
+ NotifyTypeWebhook = "webhook" // Webhook
+)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 7e74bec2..4b0cc169 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testAllChannels(notify bool) error {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
- }
+
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
- err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
- if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
- }
+ service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
}
})
return nil
diff --git a/controller/pricing.go b/controller/pricing.go
index 36caff9d..d7af5a4c 100644
--- a/controller/pricing.go
+++ b/controller/pricing.go
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
}
var group string
if exists {
- user, err := model.GetUserById(userId.(int), false)
+ user, err := model.GetUserCache(userId.(int))
if err == nil {
group = user.Group
}
diff --git a/controller/user.go b/controller/user.go
index 7146f00e..3002a613 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
+ "net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
if err != nil {
id = c.GetInt("id")
}
- user, err := model.GetUserById(id, true)
+ user, err := model.GetUserCache(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
})
return
}
- if user.Role == common.RoleRootUser {
- common.RootUserEmail = email
- }
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
})
return
}
+
+type UpdateUserSettingRequest struct {
+ QuotaWarningType string `json:"notify_type"`
+ QuotaWarningThreshold int `json:"quota_warning_threshold"`
+ WebhookUrl string `json:"webhook_url,omitempty"`
+ WebhookSecret string `json:"webhook_secret,omitempty"`
+ NotificationEmail string `json:"notification_email,omitempty"`
+}
+
+func UpdateUserSetting(c *gin.Context) {
+ var req UpdateUserSettingRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的参数",
+ })
+ return
+ }
+
+ // 验证预警类型
+ if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的预警类型",
+ })
+ return
+ }
+
+ // 验证预警阈值
+ if req.QuotaWarningThreshold <= 0 {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "预警阈值必须大于0",
+ })
+ return
+ }
+
+ // 如果是webhook类型,验证webhook地址
+ if req.QuotaWarningType == constant.NotifyTypeWebhook {
+ if req.WebhookUrl == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "Webhook地址不能为空",
+ })
+ return
+ }
+ // 验证URL格式
+ if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的Webhook地址",
+ })
+ return
+ }
+ }
+
+ // 如果是邮件类型,验证邮箱地址
+ if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+ // 验证邮箱格式
+ if !strings.Contains(req.NotificationEmail, "@") {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无效的邮箱地址",
+ })
+ return
+ }
+ }
+
+ userId := c.GetInt("id")
+ user, err := model.GetUserById(userId, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ // 构建设置
+ settings := map[string]interface{}{
+ constant.UserSettingNotifyType: req.QuotaWarningType,
+ constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
+ }
+
+ // 如果是webhook类型,添加webhook相关设置
+ if req.QuotaWarningType == constant.NotifyTypeWebhook {
+ settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
+ if req.WebhookSecret != "" {
+ settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
+ }
+ }
+
+ // 如果提供了通知邮箱,添加到设置中
+ if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
+ settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
+ }
+
+ // 更新用户设置
+ user.SetSetting(settings)
+ if err := user.Update(false); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "更新设置失败: " + err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "设置已更新",
+ })
+}
diff --git a/docker-compose.yml b/docker-compose.yml
index 640cf074..0f23cea2 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -24,7 +24,7 @@ services:
- redis
- mysql
healthcheck:
- test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
+ test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
interval: 30s
timeout: 10s
retries: 3
diff --git a/dto/notify.go b/dto/notify.go
new file mode 100644
index 00000000..b75cec70
--- /dev/null
+++ b/dto/notify.go
@@ -0,0 +1,25 @@
+package dto
+
+type Notify struct {
+ Type string `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Values []interface{} `json:"values"`
+}
+
+const ContentValueParam = "{{value}}"
+
+const (
+ NotifyTypeQuotaExceed = "quota_exceed"
+ NotifyTypeChannelUpdate = "channel_update"
+ NotifyTypeChannelTest = "channel_test"
+)
+
+func NewNotify(t string, title string, content string, values []interface{}) Notify {
+ return Notify{
+ Type: t,
+ Title: title,
+ Content: content,
+ Values: values,
+ }
+}
diff --git a/dto/openai_request.go b/dto/openai_request.go
index 58a4ce73..a142b437 100644
--- a/dto/openai_request.go
+++ b/dto/openai_request.go
@@ -18,6 +18,7 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
+ Prefix any `json:"prefix,omitempty"`
Suffix any `json:"suffix,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
diff --git a/model/token.go b/model/token.go
index 3abd22cf..8587ea62 100644
--- a/model/token.go
+++ b/model/token.go
@@ -3,13 +3,11 @@ package model
import (
"errors"
"fmt"
+ "one-api/common"
+ "strings"
+
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
- "one-api/common"
- relaycommon "one-api/relay/common"
- "one-api/setting"
- "strconv"
- "strings"
)
type Token struct {
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
).Error
return err
}
-
-func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
- if quota < 0 {
- return errors.New("quota 不能为负数!")
- }
- if relayInfo.IsPlayground {
- return nil
- }
- //if relayInfo.TokenUnlimited {
- // return nil
- //}
- token, err := GetTokenById(relayInfo.TokenId)
- if err != nil {
- return err
- }
- if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
- return errors.New("令牌额度不足")
- }
- err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
- if err != nil {
- return err
- }
- return nil
-}
-
-func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
-
- if quota > 0 {
- err = DecreaseUserQuota(relayInfo.UserId, quota)
- } else {
- err = IncreaseUserQuota(relayInfo.UserId, -quota)
- }
- if err != nil {
- return err
- }
-
- if !relayInfo.IsPlayground {
- if quota > 0 {
- err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
- } else {
- err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
- }
- if err != nil {
- return err
- }
- }
-
- if sendEmail {
- if (quota + preConsumedQuota) != 0 {
- quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
- noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
- if quotaTooLow || noMoreQuota {
- go func() {
- email, err := GetUserEmail(relayInfo.UserId)
- if err != nil {
- common.SysError("failed to fetch user email: " + err.Error())
- }
- prompt := "您的额度即将用尽"
- if noMoreQuota {
- prompt = "您的额度已用尽"
- }
- if email != "" {
- topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
- err = common.SendEmail(prompt, email,
- fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。
充值链接:%s", prompt, userQuota, topUpLink, topUpLink))
- if err != nil {
- common.SysError("failed to send email" + err.Error())
- }
- common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
- }
- }()
- }
- }
- }
-
- return nil
-}
diff --git a/model/token_cache.go b/model/token_cache.go
index 99b762f5..0fe02fea 100644
--- a/model/token_cache.go
+++ b/model/token_cache.go
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
- return nil, nil
+ return nil, fmt.Errorf("redis is not enabled")
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
diff --git a/model/user.go b/model/user.go
index 95123c21..427b0625 100644
--- a/model/user.go
+++ b/model/user.go
@@ -1,6 +1,7 @@
package model
import (
+ "encoding/json"
"errors"
"fmt"
"one-api/common"
@@ -38,6 +39,20 @@ type User struct {
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
+ Setting string `json:"setting" gorm:"type:text;column:setting"`
+}
+
+func (user *User) ToBaseUser() *UserBase {
+ cache := &UserBase{
+ Id: user.Id,
+ Group: user.Group,
+ Quota: user.Quota,
+ Status: user.Status,
+ Username: user.Username,
+ Setting: user.Setting,
+ Email: user.Email,
+ }
+ return cache
}
func (user *User) GetAccessToken() string {
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
user.AccessToken = &token
}
+func (user *User) GetSetting() map[string]interface{} {
+ if user.Setting == "" {
+ return nil
+ }
+ return common.StrToMap(user.Setting)
+}
+
+func (user *User) SetSetting(setting map[string]interface{}) {
+ settingBytes, err := json.Marshal(setting)
+ if err != nil {
+ common.SysError("failed to marshal setting: " + err.Error())
+ return
+ }
+ user.Setting = string(settingBytes)
+}
+
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
return err
}
- // 更新缓存
- return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
+ // Update cache
+ return updateUserCache(*user)
}
func (user *User) Edit(updatePassword bool) error {
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
- // 更新缓存
- return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
+ // Update cache
+ return updateUserCache(*user)
}
func (user *User) Delete() error {
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,
- // that means if your field’s value is 0, '', false or other zero values,
- // it won’t be used to build query conditions
+ // that means if your field's value is 0, '', false or other zero values,
+ // it won't be used to build query conditions
password := user.Password
username := strings.TrimSpace(user.Username)
if username == "" || password == "" {
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
return quota, nil
}
// Don't return error - fall through to DB
- //common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
return group, nil
}
+// GetUserSetting gets setting from Redis first, falls back to DB if needed
+func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
+ var setting string
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) {
+ gopool.Go(func() {
+ if err := updateUserSettingCache(id, setting); err != nil {
+ common.SysError("failed to update user setting cache: " + err.Error())
+ }
+ })
+ }
+ }()
+ if !fromDB && common.RedisEnabled {
+ setting, err := getUserSettingCache(id)
+ if err == nil {
+ return setting, nil
+ }
+ // Don't return error - fall through to DB
+ }
+ fromDB = true
+ err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
+ if err != nil {
+ return map[string]interface{}{}, err
+ }
+
+ return common.StrToMap(setting), nil
+}
+
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
}
}
-func GetRootUserEmail() (email string) {
- DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
- return email
+//func GetRootUserEmail() (email string) {
+// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+// return email
+//}
+
+func GetRootUser() (user *User) {
+ DB.Where("role = ?", common.RoleRootUser).First(&user)
+ return user
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
return !errors.Is(err, gorm.ErrRecordNotFound)
}
-func (u *User) FillUserByLinuxDOId() error {
- if u.LinuxDOId == "" {
+func (user *User) FillUserByLinuxDOId() error {
+ if user.LinuxDOId == "" {
return errors.New("linux do id is empty")
}
- err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
+ err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}
diff --git a/model/user_cache.go b/model/user_cache.go
index 9dc7e899..cc08288d 100644
--- a/model/user_cache.go
+++ b/model/user_cache.go
@@ -1,206 +1,213 @@
package model
import (
+ "encoding/json"
"fmt"
"one-api/common"
"one-api/constant"
- "strconv"
"time"
+
+ "github.com/bytedance/gopkg/util/gopool"
)
-// Change UserCache struct to userCache
-type userCache struct {
+// UserBase struct remains the same as it represents the cached data structure
+type UserBase struct {
Id int `json:"id"`
Group string `json:"group"`
+ Email string `json:"email"`
Quota int `json:"quota"`
Status int `json:"status"`
- Role int `json:"role"`
Username string `json:"username"`
+ Setting string `json:"setting"`
}
-// Rename all exported functions to private ones
-// invalidateUserCache clears all user related cache
+func (user *UserBase) GetSetting() map[string]interface{} {
+ if user.Setting == "" {
+ return nil
+ }
+ return common.StrToMap(user.Setting)
+}
+
+func (user *UserBase) SetSetting(setting map[string]interface{}) {
+ settingBytes, err := json.Marshal(setting)
+ if err != nil {
+ common.SysError("failed to marshal setting: " + err.Error())
+ return
+ }
+ user.Setting = string(settingBytes)
+}
+
+// getUserCacheKey returns the key for user cache
+func getUserCacheKey(userId int) string {
+ return fmt.Sprintf("user:%d", userId)
+}
+
+// invalidateUserCache clears user cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
+ return common.RedisHDelObj(getUserCacheKey(userId))
+}
- keys := []string{
- fmt.Sprintf(constant.UserGroupKeyFmt, userId),
- fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
- fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
- fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
+// updateUserCache updates all user cache fields using hash
+func updateUserCache(user User) error {
+ if !common.RedisEnabled {
+ return nil
}
- for _, key := range keys {
- if err := common.RedisDel(key); err != nil {
- return fmt.Errorf("failed to delete cache key %s: %w", key, err)
+ return common.RedisHSetObj(
+ getUserCacheKey(user.Id),
+ user.ToBaseUser(),
+ time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
+ )
+}
+
+// GetUserCache gets complete user cache from hash
+func GetUserCache(userId int) (userCache *UserBase, err error) {
+ var user *User
+ var fromDB bool
+ defer func() {
+ // Update Redis cache asynchronously on successful DB read
+ if shouldUpdateRedis(fromDB, err) && user != nil {
+ gopool.Go(func() {
+ if err := updateUserCache(*user); err != nil {
+ common.SysError("failed to update user status cache: " + err.Error())
+ }
+ })
}
- }
- return nil
-}
+ }()
-// updateUserGroupCache updates user group cache
-func updateUserGroupCache(userId int, group string) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisSet(
- fmt.Sprintf(constant.UserGroupKeyFmt, userId),
- group,
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
- )
-}
-
-// updateUserQuotaCache updates user quota cache
-func updateUserQuotaCache(userId int, quota int) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisSet(
- fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
- fmt.Sprintf("%d", quota),
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
- )
-}
-
-// updateUserStatusCache updates user status cache
-func updateUserStatusCache(userId int, userEnabled bool) error {
- if !common.RedisEnabled {
- return nil
- }
- enabled := "0"
- if userEnabled {
- enabled = "1"
- }
- return common.RedisSet(
- fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
- enabled,
- time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
- )
-}
-
-// updateUserNameCache updates username cache
-func updateUserNameCache(userId int, username string) error {
- if !common.RedisEnabled {
- return nil
- }
- return common.RedisSet(
- fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
- username,
- time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
- )
-}
-
-// updateUserCache updates all user cache fields
-func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
- if !common.RedisEnabled {
- return nil
+ // Try getting from Redis first
+ userCache, err = cacheGetUserBase(userId)
+ if err == nil {
+ return userCache, nil
}
- if err := updateUserGroupCache(userId, userGroup); err != nil {
- return fmt.Errorf("update group cache: %w", err)
- }
-
- if err := updateUserQuotaCache(userId, quota); err != nil {
- return fmt.Errorf("update quota cache: %w", err)
- }
-
- if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
- return fmt.Errorf("update status cache: %w", err)
- }
-
- if err := updateUserNameCache(userId, username); err != nil {
- return fmt.Errorf("update username cache: %w", err)
- }
-
- return nil
-}
-
-// getUserGroupCache gets user group from cache
-func getUserGroupCache(userId int) (string, error) {
- if !common.RedisEnabled {
- return "", nil
- }
- return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
-}
-
-// getUserQuotaCache gets user quota from cache
-func getUserQuotaCache(userId int) (int, error) {
- if !common.RedisEnabled {
- return 0, nil
- }
- quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
+ // If Redis fails, get from DB
+ fromDB = true
+ user, err = GetUserById(userId, false)
if err != nil {
- return 0, err
+ return nil, err // Return nil and error if DB lookup fails
}
- return strconv.Atoi(quotaStr)
+
+ // Create cache object from user data
+ userCache = &UserBase{
+ Id: user.Id,
+ Group: user.Group,
+ Quota: user.Quota,
+ Status: user.Status,
+ Username: user.Username,
+ Setting: user.Setting,
+ Email: user.Email,
+ }
+
+ return userCache, nil
}
-// getUserStatusCache gets user status from cache
-func getUserStatusCache(userId int) (int, error) {
+func cacheGetUserBase(userId int) (*UserBase, error) {
if !common.RedisEnabled {
- return 0, nil
+ return nil, fmt.Errorf("redis is not enabled")
}
- statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
+ var userCache UserBase
+ // Try getting from Redis first
+ err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
if err != nil {
- return 0, err
+ return nil, err
}
- return strconv.Atoi(statusStr)
+ return &userCache, nil
}
-// getUserNameCache gets username from cache
-func getUserNameCache(userId int) (string, error) {
- if !common.RedisEnabled {
- return "", nil
- }
- return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
-}
-
-// getUserCache gets complete user cache
-func getUserCache(userId int) (*userCache, error) {
- if !common.RedisEnabled {
- return nil, nil
- }
-
- group, err := getUserGroupCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get group cache: %w", err)
- }
-
- quota, err := getUserQuotaCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get quota cache: %w", err)
- }
-
- status, err := getUserStatusCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get status cache: %w", err)
- }
-
- username, err := getUserNameCache(userId)
- if err != nil {
- return nil, fmt.Errorf("get username cache: %w", err)
- }
-
- return &userCache{
- Id: userId,
- Group: group,
- Quota: quota,
- Status: status,
- Username: username,
- }, nil
-}
-
-// Add atomic quota operations
+// Add atomic quota operations using hash fields
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
- key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
- return common.RedisIncr(key, delta)
+ return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
}
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}
+
+// Helper functions to get individual fields if needed
+func getUserGroupCache(userId int) (string, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return "", err
+ }
+ return cache.Group, nil
+}
+
+func getUserQuotaCache(userId int) (int, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return 0, err
+ }
+ return cache.Quota, nil
+}
+
+func getUserStatusCache(userId int) (int, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return 0, err
+ }
+ return cache.Status, nil
+}
+
+func getUserNameCache(userId int) (string, error) {
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return "", err
+ }
+ return cache.Username, nil
+}
+
+func getUserSettingCache(userId int) (map[string]interface{}, error) {
+ setting := make(map[string]interface{})
+ cache, err := GetUserCache(userId)
+ if err != nil {
+ return setting, err
+ }
+ return cache.GetSetting(), nil
+}
+
+// New functions for individual field updates
+func updateUserStatusCache(userId int, status bool) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ statusInt := common.UserStatusEnabled
+ if !status {
+ statusInt = common.UserStatusDisabled
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
+}
+
+func updateUserQuotaCache(userId int, quota int) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
+}
+
+func updateUserGroupCache(userId int, group string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
+}
+
+func updateUserNameCache(userId int, username string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
+}
+
+func updateUserSettingCache(userId int, setting string) error {
+ if !common.RedisEnabled {
+ return nil
+ }
+ return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
+}
diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go
index 75400098..5c2eadc2 100644
--- a/relay/channel/cloudflare/adaptor.go
+++ b/relay/channel/cloudflare/adaptor.go
@@ -4,13 +4,14 @@ import (
"bytes"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
+
+ "github.com/gin-gonic/gin"
)
type Adaptor struct {
diff --git a/relay/relay-mj.go b/relay/relay-mj.go
index 0facecab..766064cb 100644
--- a/relay/relay-mj.go
+++ b/relay/relay-mj.go
@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/relay/relay-text.go b/relay/relay-text.go
index f303ff6a..f9d1bd03 100644
--- a/relay/relay-text.go
+++ b/relay/relay-text.go
@@ -272,7 +272,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
if userQuota-preConsumedQuota < 0 {
- return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
+ return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %d", common.FormatQuota(userQuota), preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
@@ -282,18 +282,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
+ common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
- common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
+ common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
}
}
if preConsumedQuota > 0 {
- err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
+ err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
@@ -310,7 +310,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
go func() {
relayInfoCopy := *relayInfo
- err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
+ err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
@@ -368,7 +368,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+ err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
diff --git a/relay/relay_task.go b/relay/relay_task.go
index 61577faf..f03fcb2d 100644
--- a/relay/relay_task.go
+++ b/relay/relay_task.go
@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
- err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
+ err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
diff --git a/router/api-router.go b/router/api-router.go
index b00595af..bf88449a 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/pay", controller.RequestEpay)
selfRoute.POST("/amount", controller.RequestAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
+ selfRoute.PUT("/setting", controller.UpdateUserSetting)
}
adminRoute := userRoute.Group("/")
diff --git a/service/cf_worker.go b/service/cf_worker.go
index afe65411..40a1e294 100644
--- a/service/cf_worker.go
+++ b/service/cf_worker.go
@@ -2,6 +2,7 @@ package service
import (
"bytes"
+ "encoding/json"
"fmt"
"net/http"
"one-api/common"
@@ -9,19 +10,46 @@ import (
"strings"
)
+// WorkerRequest Worker请求的数据结构
+type WorkerRequest struct {
+ URL string `json:"url"`
+ Key string `json:"key"`
+ Method string `json:"method,omitempty"`
+ Headers map[string]string `json:"headers,omitempty"`
+ Body json.RawMessage `json:"body,omitempty"`
+}
+
+// DoWorkerRequest 通过Worker发送请求
+func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
+ if !setting.EnableWorker() {
+ return nil, fmt.Errorf("worker not enabled")
+ }
+ if !strings.HasPrefix(req.URL, "https") {
+ return nil, fmt.Errorf("only support https url")
+ }
+
+ workerUrl := setting.WorkerUrl
+ if !strings.HasSuffix(workerUrl, "/") {
+ workerUrl += "/"
+ }
+
+ // 序列化worker请求数据
+ workerPayload, err := json.Marshal(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
+ }
+
+ return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
+}
+
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
- if !strings.HasPrefix(originUrl, "https") {
- return nil, fmt.Errorf("only support https url")
+ req := &WorkerRequest{
+ URL: originUrl,
+ Key: setting.WorkerValidKey,
}
- workerUrl := setting.WorkerUrl
- if !strings.HasSuffix(workerUrl, "/") {
- workerUrl += "/"
- }
- // post request to worker
- data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
- return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
+ return DoWorkerRequest(req)
} else {
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
diff --git a/service/channel.go b/service/channel.go
index 73545b1e..76bcacf1 100644
--- a/service/channel.go
+++ b/service/channel.go
@@ -4,7 +4,7 @@ import (
"fmt"
"net/http"
"one-api/common"
- relaymodel "one-api/dto"
+ "one-api/dto"
"one-api/model"
"one-api/setting"
"strings"
@@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
- notifyRootUser(subject, content)
+ NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
- notifyRootUser(subject, content)
+ NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
}
-func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
+func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
@@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
return false
}
-func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
+func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
if !common.AutomaticEnableChannelEnabled {
return false
}
diff --git a/service/notify-limit.go b/service/notify-limit.go
new file mode 100644
index 00000000..7bb62f62
--- /dev/null
+++ b/service/notify-limit.go
@@ -0,0 +1,116 @@
+package service
+
+import (
+ "fmt"
+ "one-api/common"
+ "one-api/constant"
+ "strconv"
+ "sync"
+ "time"
+)
+
+// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
+var (
+ notifyLimitStore sync.Map
+ cleanupOnce sync.Once
+)
+
+type limitCount struct {
+ Count int
+ Timestamp time.Time
+}
+
+func getDuration() time.Duration {
+ minute := constant.NotificationLimitDurationMinute
+ return time.Duration(minute) * time.Minute
+}
+
+// startCleanupTask starts a background task to clean up expired entries
+func startCleanupTask() {
+ go func() {
+ for {
+ time.Sleep(time.Hour)
+ now := time.Now()
+ notifyLimitStore.Range(func(key, value interface{}) bool {
+ if limit, ok := value.(limitCount); ok {
+ if now.Sub(limit.Timestamp) >= getDuration() {
+ notifyLimitStore.Delete(key)
+ }
+ }
+ return true
+ })
+ }
+ }()
+}
+
+// CheckNotificationLimit checks if the user has exceeded their notification limit
+// Returns true if the user can send notification, false if limit exceeded
+func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
+ if common.RedisEnabled {
+ return checkRedisLimit(userId, notifyType)
+ }
+ return checkMemoryLimit(userId, notifyType)
+}
+
+func checkRedisLimit(userId int, notifyType string) (bool, error) {
+ key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+
+ // Get current count
+ count, err := common.RedisGet(key)
+ if err != nil && err.Error() != "redis: nil" {
+ return false, fmt.Errorf("failed to get notification count: %w", err)
+ }
+
+ // If key doesn't exist, initialize it
+ if count == "" {
+ err = common.RedisSet(key, "1", getDuration())
+ return true, err
+ }
+
+ currentCount, _ := strconv.Atoi(count)
+ limit := constant.NotifyLimitCount
+
+ // Check if limit is already reached
+ if currentCount >= limit {
+ return false, nil
+ }
+
+ // Only increment if under limit
+ err = common.RedisIncr(key, 1)
+ if err != nil {
+ return false, fmt.Errorf("failed to increment notification count: %w", err)
+ }
+
+ return true, nil
+}
+
+func checkMemoryLimit(userId int, notifyType string) (bool, error) {
+ // Ensure cleanup task is started
+ cleanupOnce.Do(startCleanupTask)
+
+ key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
+ now := time.Now()
+
+ // Get current limit count or initialize new one
+ var currentLimit limitCount
+ if value, ok := notifyLimitStore.Load(key); ok {
+ currentLimit = value.(limitCount)
+ // Check if the entry has expired
+ if now.Sub(currentLimit.Timestamp) >= getDuration() {
+ currentLimit = limitCount{Count: 0, Timestamp: now}
+ }
+ } else {
+ currentLimit = limitCount{Count: 0, Timestamp: now}
+ }
+
+ // Increment count
+ currentLimit.Count++
+
+ // Check against limits
+ limit := constant.NotifyLimitCount
+
+ // Store updated count
+ notifyLimitStore.Store(key, currentLimit)
+
+ return currentLimit.Count <= limit, nil
+}
diff --git a/service/quota.go b/service/quota.go
index ab048008..2ec04fe0 100644
--- a/service/quota.go
+++ b/service/quota.go
@@ -3,8 +3,10 @@ package service
import (
"errors"
"fmt"
+ "github.com/bytedance/gopkg/util/gopool"
"math"
"one-api/common"
+ constant2 "one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
@@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
}
- err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
+ err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
@@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
- err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
+ err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
@@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
+
+func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
+ if quota < 0 {
+ return errors.New("quota 不能为负数!")
+ }
+ if relayInfo.IsPlayground {
+ return nil
+ }
+ //if relayInfo.TokenUnlimited {
+ // return nil
+ //}
+ token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
+ if err != nil {
+ return err
+ }
+ if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
+ return errors.New("令牌额度不足")
+ }
+ err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
+
+ if quota > 0 {
+ err = model.DecreaseUserQuota(relayInfo.UserId, quota)
+ } else {
+ err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
+ }
+ if err != nil {
+ return err
+ }
+
+ if !relayInfo.IsPlayground {
+ if quota > 0 {
+ err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
+ } else {
+ err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if sendEmail {
+ if (quota + preConsumedQuota) != 0 {
+ checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
+ }
+ }
+
+ return nil
+}
+
+func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
+ gopool.Go(func() {
+ userCache, err := model.GetUserCache(userId)
+ if err != nil {
+ common.SysError("failed to get user cache: " + err.Error())
+ }
+ userSetting := userCache.GetSetting()
+ threshold := common.QuotaRemindThreshold
+ if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
+ threshold = int(userCustomThreshold.(float64))
+ }
+
+ //noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
+ quotaTooLow := false
+ consumeQuota := quota + preConsumedQuota
+ if userCache.Quota-consumeQuota < threshold {
+ quotaTooLow = true
+ }
+ if quotaTooLow {
+ prompt := "您的额度即将用尽"
+ topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
+ content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}"
+ err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
+ }
+ }
+ })
+}
diff --git a/service/user_notify.go b/service/user_notify.go
index 7ae9062b..e01b7aa9 100644
--- a/service/user_notify.go
+++ b/service/user_notify.go
@@ -3,15 +3,75 @@ package service
import (
"fmt"
"one-api/common"
+ "one-api/constant"
+ "one-api/dto"
"one-api/model"
+ "strings"
)
-func notifyRootUser(subject string, content string) {
- if common.RootUserEmail == "" {
- common.RootUserEmail = model.GetRootUserEmail()
- }
- err := common.SendEmail(subject, common.RootUserEmail, content)
- if err != nil {
- common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
- }
+func NotifyRootUser(t string, subject string, content string) {
+ user := model.GetRootUser().ToBaseUser()
+ _ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
+}
+
+func NotifyUser(user *model.UserBase, data dto.Notify) error {
+ userSetting := user.GetSetting()
+ notifyType, ok := userSetting[constant.UserSettingNotifyType]
+ if !ok {
+ notifyType = constant.NotifyTypeEmail
+ }
+
+ // Check notification limit
+ canSend, err := CheckNotificationLimit(user.Id, data.Type)
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
+ return err
+ }
+ if !canSend {
+ return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
+ }
+
+ switch notifyType {
+ case constant.NotifyTypeEmail:
+ userEmail := user.Email
+ // check setting email
+ if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
+ userEmail = settingEmail.(string)
+ }
+ if userEmail == "" {
+ common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
+ return nil
+ }
+ return sendEmailNotify(userEmail, data)
+ case constant.NotifyTypeWebhook:
+ webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
+ if !ok {
+ common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
+ return nil
+ }
+ webhookURLStr, ok := webhookURL.(string)
+ if !ok {
+ common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
+ return nil
+ }
+
+ // 获取 webhook secret
+ var webhookSecret string
+ if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
+ webhookSecret, _ = secret.(string)
+ }
+
+ return SendWebhookNotify(webhookURLStr, webhookSecret, data)
+ }
+ return nil
+}
+
+func sendEmailNotify(userEmail string, data dto.Notify) error {
+ // make email content
+ content := data.Content
+ // 处理占位符
+ for _, value := range data.Values {
+ content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
+ }
+ return common.SendEmail(data.Title, userEmail, content)
}
diff --git a/service/webhook.go b/service/webhook.go
new file mode 100644
index 00000000..ad2967eb
--- /dev/null
+++ b/service/webhook.go
@@ -0,0 +1,118 @@
+package service
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "one-api/dto"
+ "one-api/setting"
+ "time"
+)
+
+// WebhookPayload webhook 通知的负载数据
+type WebhookPayload struct {
+ Type string `json:"type"`
+ Title string `json:"title"`
+ Content string `json:"content"`
+ Values []interface{} `json:"values,omitempty"`
+ Timestamp int64 `json:"timestamp"`
+}
+
+// generateSignature 生成 webhook 签名
+func generateSignature(secret string, payload []byte) string {
+ h := hmac.New(sha256.New, []byte(secret))
+ h.Write(payload)
+ return hex.EncodeToString(h.Sum(nil))
+}
+
+// SendWebhookNotify 发送 webhook 通知
+func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
+ // 处理占位符
+ content := data.Content
+ for _, value := range data.Values {
+ content = fmt.Sprintf(content, value)
+ }
+
+ // 构建 webhook 负载
+ payload := WebhookPayload{
+ Type: data.Type,
+ Title: data.Title,
+ Content: content,
+ Values: data.Values,
+ Timestamp: time.Now().Unix(),
+ }
+
+ // 序列化负载
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("failed to marshal webhook payload: %v", err)
+ }
+
+ // 创建 HTTP 请求
+ var req *http.Request
+ var resp *http.Response
+
+ if setting.EnableWorker() {
+ // 构建worker请求数据
+ workerReq := &WorkerRequest{
+ URL: webhookURL,
+ Key: setting.WorkerValidKey,
+ Method: http.MethodPost,
+ Headers: map[string]string{
+ "Content-Type": "application/json",
+ },
+ Body: payloadBytes,
+ }
+
+ // 如果有secret,添加签名到headers
+ if secret != "" {
+ signature := generateSignature(secret, payloadBytes)
+ workerReq.Headers["X-Webhook-Signature"] = signature
+ workerReq.Headers["Authorization"] = "Bearer " + secret
+ }
+
+ resp, err = DoWorkerRequest(workerReq)
+ if err != nil {
+ return fmt.Errorf("failed to send webhook request through worker: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // 检查响应状态
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+ }
+ } else {
+ req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
+ if err != nil {
+ return fmt.Errorf("failed to create webhook request: %v", err)
+ }
+
+ // 设置请求头
+ req.Header.Set("Content-Type", "application/json")
+
+ // 如果有 secret,生成签名
+ if secret != "" {
+ signature := generateSignature(secret, payloadBytes)
+ req.Header.Set("X-Webhook-Signature", signature)
+ }
+
+ // 发送请求
+ client := GetImpatientHttpClient()
+ resp, err = client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send webhook request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ // 检查响应状态
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
+ }
+ }
+
+ return nil
+}
diff --git a/setting/system-setting.go b/setting/system_setting.go
similarity index 100%
rename from setting/system-setting.go
rename to setting/system_setting.go
diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js
index 2f112c37..777cf042 100644
--- a/web/src/components/PersonalSetting.js
+++ b/web/src/components/PersonalSetting.js
@@ -26,6 +26,10 @@ import {
Tag,
Typography,
Collapsible,
+ Select,
+ Radio,
+ RadioGroup,
+ AutoComplete,
} from '@douyinfe/semi-ui';
import {
getQuotaPerUnit,
@@ -67,14 +71,16 @@ const PersonalSetting = () => {
const [transferAmount, setTransferAmount] = useState(0);
const [isModelsExpanded, setIsModelsExpanded] = useState(false);
const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量
+ const [notificationSettings, setNotificationSettings] = useState({
+ warningType: 'email',
+ warningThreshold: 100000,
+ webhookUrl: '',
+ webhookSecret: '',
+ notificationEmail: ''
+ });
+ const [showWebhookDocs, setShowWebhookDocs] = useState(false);
useEffect(() => {
- // let user = localStorage.getItem('user');
- // if (user) {
- // userDispatch({ type: 'login', payload: user });
- // }
- // console.log(localStorage.getItem('user'))
-
let status = localStorage.getItem('status');
if (status) {
status = JSON.parse(status);
@@ -105,6 +111,19 @@ const PersonalSetting = () => {
return () => clearInterval(countdownInterval); // Clean up on unmount
}, [disableButton, countdown]);
+ useEffect(() => {
+ if (userState?.user?.setting) {
+ const settings = JSON.parse(userState.user.setting);
+ setNotificationSettings({
+ warningType: settings.notify_type || 'email',
+ warningThreshold: settings.quota_warning_threshold || 500000,
+ webhookUrl: settings.webhook_url || '',
+ webhookSecret: settings.webhook_secret || '',
+ notificationEmail: settings.notification_email || ''
+ });
+ }
+ }, [userState?.user?.setting]);
+
const handleInputChange = (name, value) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
@@ -300,7 +319,36 @@ const PersonalSetting = () => {
}
};
+ const handleNotificationSettingChange = (type, value) => {
+ setNotificationSettings(prev => ({
+ ...prev,
+ [type]: value.target ? value.target.value : value // 处理 Radio 事件对象
+ }));
+ };
+
+ const saveNotificationSettings = async () => {
+ try {
+ const res = await API.put('/api/user/setting', {
+ notify_type: notificationSettings.warningType,
+ quota_warning_threshold: notificationSettings.warningThreshold,
+ webhook_url: notificationSettings.webhookUrl,
+ webhook_secret: notificationSettings.webhookSecret,
+ notification_email: notificationSettings.notificationEmail
+ });
+
+ if (res.data.success) {
+ showSuccess(t('通知设置已更新'));
+ await getUserData();
+ } else {
+ showError(res.data.message);
+ }
+ } catch (error) {
+ showError(t('更新通知设置失败'));
+ }
+ };
+
return (
+
+{`{
+ "type": "quota_exceed", // 通知类型
+ "title": "标题", // 通知标题
+ "content": "通知内容", // 通知内容,支持 {{value}} 变量占位符
+ "values": ["值1", "值2"], // 按顺序替换content中的 {{value}} 占位符
+ "timestamp": 1739950503 // 时间戳
+}
+
+示例:
+{
+ "type": "quota_exceed",
+ "title": "额度预警通知",
+ "content": "您的额度即将用尽,当前剩余额度为 {{value}}",
+ "values": ["$0.99"],
+ "timestamp": 1739950503
+}`}
+
+