Merge pull request #770 from Calcium-Ion/refactor_notify
feat: Add user notification settings and multiple notification methods
This commit is contained in:
@@ -89,6 +89,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
|||||||
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
||||||
- `CRYPTO_SECRET`: Encryption key for encrypting database content
|
- `CRYPTO_SECRET`: Encryption key for encrypting database content
|
||||||
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
|
- `AZURE_DEFAULT_API_VERSION`: Azure channel default API version, if not specified in channel settings, use this version, default `2024-12-01-preview`
|
||||||
|
- `NOTIFICATION_LIMIT_DURATION_MINUTE`: Duration of notification limit in minutes, default `10`
|
||||||
|
- `NOTIFY_LIMIT_COUNT`: Maximum number of user notifications in the specified duration, default `2`
|
||||||
|
|
||||||
## Deployment
|
## Deployment
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,9 @@
|
|||||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
||||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
|
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
|
||||||
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
|
- `AZURE_DEFAULT_API_VERSION`:Azure渠道默认API版本,如果渠道设置中未指定API版本,则使用此版本,默认为 `2024-12-01-preview`
|
||||||
|
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
|
||||||
|
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ var PreConsumedQuota = 500
|
|||||||
|
|
||||||
var RetryTimes = 0
|
var RetryTimes = 0
|
||||||
|
|
||||||
var RootUserEmail = ""
|
//var RootUserEmail = ""
|
||||||
|
|
||||||
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
|
||||||
|
|
||||||
|
|||||||
@@ -100,6 +100,14 @@ func LogQuota(quota int) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FormatQuota(quota int) string {
|
||||||
|
if DisplayInCurrencyEnabled {
|
||||||
|
return fmt.Sprintf("$%.6f", float64(quota)/QuotaPerUnit)
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf("%d", quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// LogJson 仅供测试使用 only for test
|
// LogJson 仅供测试使用 only for test
|
||||||
func LogJson(ctx context.Context, msg string, obj any) {
|
func LogJson(ctx context.Context, msg string, obj any) {
|
||||||
jsonStr, err := json.Marshal(obj)
|
jsonStr, err := json.Marshal(obj)
|
||||||
|
|||||||
@@ -233,7 +233,11 @@ var (
|
|||||||
modelRatioMapMutex = sync.RWMutex{}
|
modelRatioMapMutex = sync.RWMutex{}
|
||||||
)
|
)
|
||||||
|
|
||||||
var CompletionRatio map[string]float64 = nil
|
var (
|
||||||
|
CompletionRatio map[string]float64 = nil
|
||||||
|
CompletionRatioMutex = sync.RWMutex{}
|
||||||
|
)
|
||||||
|
|
||||||
var defaultCompletionRatio = map[string]float64{
|
var defaultCompletionRatio = map[string]float64{
|
||||||
"gpt-4-gizmo-*": 2,
|
"gpt-4-gizmo-*": 2,
|
||||||
"gpt-4o-gizmo-*": 3,
|
"gpt-4o-gizmo-*": 3,
|
||||||
@@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 {
|
|||||||
return defaultModelRatio
|
return defaultModelRatio
|
||||||
}
|
}
|
||||||
|
|
||||||
func CompletionRatio2JSONString() string {
|
func GetCompletionRatioMap() map[string]float64 {
|
||||||
|
CompletionRatioMutex.Lock()
|
||||||
|
defer CompletionRatioMutex.Unlock()
|
||||||
if CompletionRatio == nil {
|
if CompletionRatio == nil {
|
||||||
CompletionRatio = defaultCompletionRatio
|
CompletionRatio = defaultCompletionRatio
|
||||||
}
|
}
|
||||||
|
return CompletionRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
func CompletionRatio2JSONString() string {
|
||||||
|
GetCompletionRatioMap()
|
||||||
jsonBytes, err := json.Marshal(CompletionRatio)
|
jsonBytes, err := json.Marshal(CompletionRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling completion ratio: " + err.Error())
|
SysError("error marshalling completion ratio: " + err.Error())
|
||||||
@@ -346,11 +357,15 @@ func CompletionRatio2JSONString() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
||||||
|
CompletionRatioMutex.Lock()
|
||||||
|
defer CompletionRatioMutex.Unlock()
|
||||||
CompletionRatio = make(map[string]float64)
|
CompletionRatio = make(map[string]float64)
|
||||||
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetCompletionRatio(name string) float64 {
|
func GetCompletionRatio(name string) float64 {
|
||||||
|
GetCompletionRatioMap()
|
||||||
|
|
||||||
if strings.Contains(name, "/") {
|
if strings.Contains(name, "/") {
|
||||||
if ratio, ok := CompletionRatio[name]; ok {
|
if ratio, ok := CompletionRatio[name]; ok {
|
||||||
return ratio
|
return ratio
|
||||||
@@ -476,24 +491,3 @@ func GetAudioCompletionRatio(name string) float64 {
|
|||||||
}
|
}
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
//func GetAudioPricePerMinute(name string) float64 {
|
|
||||||
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
|
||||||
// return 0.06
|
|
||||||
// }
|
|
||||||
// return 0.06
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func GetAudioCompletionPricePerMinute(name string) float64 {
|
|
||||||
// if strings.HasPrefix(name, "gpt-4o-realtime") {
|
|
||||||
// return 0.24
|
|
||||||
// }
|
|
||||||
// return 0.24
|
|
||||||
//}
|
|
||||||
|
|
||||||
func GetCompletionRatioMap() map[string]float64 {
|
|
||||||
if CompletionRatio == nil {
|
|
||||||
CompletionRatio = defaultCompletionRatio
|
|
||||||
}
|
|
||||||
return CompletionRatio
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ var GeminiModelMap = map[string]string{
|
|||||||
|
|
||||||
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||||
|
|
||||||
|
var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||||
|
var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||||
|
|
||||||
func InitEnv() {
|
func InitEnv() {
|
||||||
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||||
if modelVersionMapStr == "" {
|
if modelVersionMapStr == "" {
|
||||||
@@ -44,5 +47,5 @@ func InitEnv() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 是否生成初始令牌,默认关闭。
|
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||||
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
var GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||||
|
|||||||
14
constant/user_setting.go
Normal file
14
constant/user_setting.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
var (
|
||||||
|
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
||||||
|
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
||||||
|
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
||||||
|
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||||
|
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
NotifyTypeEmail = "email" // Email 邮件
|
||||||
|
NotifyTypeWebhook = "webhook" // Webhook
|
||||||
|
)
|
||||||
@@ -238,9 +238,7 @@ var testAllChannelsLock sync.Mutex
|
|||||||
var testAllChannelsRunning bool = false
|
var testAllChannelsRunning bool = false
|
||||||
|
|
||||||
func testAllChannels(notify bool) error {
|
func testAllChannels(notify bool) error {
|
||||||
if common.RootUserEmail == "" {
|
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
|
||||||
}
|
|
||||||
testAllChannelsLock.Lock()
|
testAllChannelsLock.Lock()
|
||||||
if testAllChannelsRunning {
|
if testAllChannelsRunning {
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
@@ -295,10 +293,7 @@ func testAllChannels(notify bool) error {
|
|||||||
testAllChannelsRunning = false
|
testAllChannelsRunning = false
|
||||||
testAllChannelsLock.Unlock()
|
testAllChannelsLock.Unlock()
|
||||||
if notify {
|
if notify {
|
||||||
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
|
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||||
if err != nil {
|
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ func GetPricing(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
var group string
|
var group string
|
||||||
if exists {
|
if exists {
|
||||||
user, err := model.GetUserById(userId.(int), false)
|
user, err := model.GetUserCache(userId.(int))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
group = user.Group
|
group = user.Group
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
@@ -471,7 +472,7 @@ func GetUserModels(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
id = c.GetInt("id")
|
id = c.GetInt("id")
|
||||||
}
|
}
|
||||||
user, err := model.GetUserById(id, true)
|
user, err := model.GetUserCache(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -869,9 +870,6 @@ func EmailBind(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if user.Role == common.RoleRootUser {
|
|
||||||
common.RootUserEmail = email
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -913,3 +911,115 @@ func TopUp(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UpdateUserSettingRequest struct {
|
||||||
|
QuotaWarningType string `json:"notify_type"`
|
||||||
|
QuotaWarningThreshold int `json:"quota_warning_threshold"`
|
||||||
|
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||||
|
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||||
|
NotificationEmail string `json:"notification_email,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdateUserSetting(c *gin.Context) {
|
||||||
|
var req UpdateUserSettingRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的参数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证预警类型
|
||||||
|
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的预警类型",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证预警阈值
|
||||||
|
if req.QuotaWarningThreshold <= 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "预警阈值必须大于0",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是webhook类型,验证webhook地址
|
||||||
|
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||||
|
if req.WebhookUrl == "" {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "Webhook地址不能为空",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 验证URL格式
|
||||||
|
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的Webhook地址",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是邮件类型,验证邮箱地址
|
||||||
|
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||||
|
// 验证邮箱格式
|
||||||
|
if !strings.Contains(req.NotificationEmail, "@") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "无效的邮箱地址",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := c.GetInt("id")
|
||||||
|
user, err := model.GetUserById(userId, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建设置
|
||||||
|
settings := map[string]interface{}{
|
||||||
|
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||||
|
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果是webhook类型,添加webhook相关设置
|
||||||
|
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||||
|
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
|
||||||
|
if req.WebhookSecret != "" {
|
||||||
|
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果提供了通知邮箱,添加到设置中
|
||||||
|
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||||
|
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新用户设置
|
||||||
|
user.SetSetting(settings)
|
||||||
|
if err := user.Update(false); err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "更新设置失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "设置已更新",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ services:
|
|||||||
- redis
|
- redis
|
||||||
- mysql
|
- mysql
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
test: ["CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $$2}'"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
|
|||||||
25
dto/notify.go
Normal file
25
dto/notify.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type Notify struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Values []interface{} `json:"values"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const ContentValueParam = "{{value}}"
|
||||||
|
|
||||||
|
const (
|
||||||
|
NotifyTypeQuotaExceed = "quota_exceed"
|
||||||
|
NotifyTypeChannelUpdate = "channel_update"
|
||||||
|
NotifyTypeChannelTest = "channel_test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewNotify(t string, title string, content string, values []interface{}) Notify {
|
||||||
|
return Notify{
|
||||||
|
Type: t,
|
||||||
|
Title: title,
|
||||||
|
Content: content,
|
||||||
|
Values: values,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ type GeneralOpenAIRequest struct {
|
|||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
Prompt any `json:"prompt,omitempty"`
|
Prompt any `json:"prompt,omitempty"`
|
||||||
|
Prefix any `json:"prefix,omitempty"`
|
||||||
Suffix any `json:"suffix,omitempty"`
|
Suffix any `json:"suffix,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||||
|
|||||||
@@ -3,13 +3,11 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
|
||||||
relaycommon "one-api/relay/common"
|
|
||||||
"one-api/setting"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Token struct {
|
type Token struct {
|
||||||
@@ -322,80 +320,3 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
|||||||
).Error
|
).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
|
||||||
if quota < 0 {
|
|
||||||
return errors.New("quota 不能为负数!")
|
|
||||||
}
|
|
||||||
if relayInfo.IsPlayground {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
//if relayInfo.TokenUnlimited {
|
|
||||||
// return nil
|
|
||||||
//}
|
|
||||||
token, err := GetTokenById(relayInfo.TokenId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
|
||||||
return errors.New("令牌额度不足")
|
|
||||||
}
|
|
||||||
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
|
||||||
|
|
||||||
if quota > 0 {
|
|
||||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
|
||||||
} else {
|
|
||||||
err = IncreaseUserQuota(relayInfo.UserId, -quota)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !relayInfo.IsPlayground {
|
|
||||||
if quota > 0 {
|
|
||||||
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
|
||||||
} else {
|
|
||||||
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sendEmail {
|
|
||||||
if (quota + preConsumedQuota) != 0 {
|
|
||||||
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
|
|
||||||
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
|
|
||||||
if quotaTooLow || noMoreQuota {
|
|
||||||
go func() {
|
|
||||||
email, err := GetUserEmail(relayInfo.UserId)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to fetch user email: " + err.Error())
|
|
||||||
}
|
|
||||||
prompt := "您的额度即将用尽"
|
|
||||||
if noMoreQuota {
|
|
||||||
prompt = "您的额度已用尽"
|
|
||||||
}
|
|
||||||
if email != "" {
|
|
||||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
|
||||||
err = common.SendEmail(prompt, email,
|
|
||||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to send email" + err.Error())
|
|
||||||
}
|
|
||||||
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func cacheSetTokenField(key string, field string, value string) error {
|
|||||||
func cacheGetTokenByKey(key string) (*Token, error) {
|
func cacheGetTokenByKey(key string) (*Token, error) {
|
||||||
hmacKey := common.GenerateHMAC(key)
|
hmacKey := common.GenerateHMAC(key)
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil, nil
|
return nil, fmt.Errorf("redis is not enabled")
|
||||||
}
|
}
|
||||||
var token Token
|
var token Token
|
||||||
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -38,6 +39,20 @@ type User struct {
|
|||||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||||
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
||||||
|
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) ToBaseUser() *UserBase {
|
||||||
|
cache := &UserBase{
|
||||||
|
Id: user.Id,
|
||||||
|
Group: user.Group,
|
||||||
|
Quota: user.Quota,
|
||||||
|
Status: user.Status,
|
||||||
|
Username: user.Username,
|
||||||
|
Setting: user.Setting,
|
||||||
|
Email: user.Email,
|
||||||
|
}
|
||||||
|
return cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) GetAccessToken() string {
|
func (user *User) GetAccessToken() string {
|
||||||
@@ -51,6 +66,22 @@ func (user *User) SetAccessToken(token string) {
|
|||||||
user.AccessToken = &token
|
user.AccessToken = &token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *User) GetSetting() map[string]interface{} {
|
||||||
|
if user.Setting == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return common.StrToMap(user.Setting)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *User) SetSetting(setting map[string]interface{}) {
|
||||||
|
settingBytes, err := json.Marshal(setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to marshal setting: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user.Setting = string(settingBytes)
|
||||||
|
}
|
||||||
|
|
||||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||||
var user User
|
var user User
|
||||||
@@ -315,8 +346,8 @@ func (user *User) Update(updatePassword bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新缓存
|
// Update cache
|
||||||
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
return updateUserCache(*user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) Edit(updatePassword bool) error {
|
func (user *User) Edit(updatePassword bool) error {
|
||||||
@@ -344,8 +375,8 @@ func (user *User) Edit(updatePassword bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新缓存
|
// Update cache
|
||||||
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
return updateUserCache(*user)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) Delete() error {
|
func (user *User) Delete() error {
|
||||||
@@ -371,8 +402,8 @@ func (user *User) HardDelete() error {
|
|||||||
// ValidateAndFill check password & user status
|
// ValidateAndFill check password & user status
|
||||||
func (user *User) ValidateAndFill() (err error) {
|
func (user *User) ValidateAndFill() (err error) {
|
||||||
// When querying with struct, GORM will only query with non-zero fields,
|
// When querying with struct, GORM will only query with non-zero fields,
|
||||||
// that means if your field’s value is 0, '', false or other zero values,
|
// that means if your field's value is 0, '', false or other zero values,
|
||||||
// it won’t be used to build query conditions
|
// it won't be used to build query conditions
|
||||||
password := user.Password
|
password := user.Password
|
||||||
username := strings.TrimSpace(user.Username)
|
username := strings.TrimSpace(user.Username)
|
||||||
if username == "" || password == "" {
|
if username == "" || password == "" {
|
||||||
@@ -531,7 +562,6 @@ func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
|||||||
return quota, nil
|
return quota, nil
|
||||||
}
|
}
|
||||||
// Don't return error - fall through to DB
|
// Don't return error - fall through to DB
|
||||||
//common.SysError("failed to get user quota from cache: " + err.Error())
|
|
||||||
}
|
}
|
||||||
fromDB = true
|
fromDB = true
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||||
@@ -580,6 +610,35 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
|||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserSetting gets setting from Redis first, falls back to DB if needed
|
||||||
|
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
|
||||||
|
var setting string
|
||||||
|
defer func() {
|
||||||
|
// Update Redis cache asynchronously on successful DB read
|
||||||
|
if shouldUpdateRedis(fromDB, err) {
|
||||||
|
gopool.Go(func() {
|
||||||
|
if err := updateUserSettingCache(id, setting); err != nil {
|
||||||
|
common.SysError("failed to update user setting cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if !fromDB && common.RedisEnabled {
|
||||||
|
setting, err := getUserSettingCache(id)
|
||||||
|
if err == nil {
|
||||||
|
return setting, nil
|
||||||
|
}
|
||||||
|
// Don't return error - fall through to DB
|
||||||
|
}
|
||||||
|
fromDB = true
|
||||||
|
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
|
||||||
|
if err != nil {
|
||||||
|
return map[string]interface{}{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return common.StrToMap(setting), nil
|
||||||
|
}
|
||||||
|
|
||||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
func IncreaseUserQuota(id int, quota int) (err error) {
|
||||||
if quota < 0 {
|
if quota < 0 {
|
||||||
return errors.New("quota 不能为负数!")
|
return errors.New("quota 不能为负数!")
|
||||||
@@ -641,9 +700,14 @@ func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRootUserEmail() (email string) {
|
//func GetRootUserEmail() (email string) {
|
||||||
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
// DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
||||||
return email
|
// return email
|
||||||
|
//}
|
||||||
|
|
||||||
|
func GetRootUser() (user *User) {
|
||||||
|
DB.Where("role = ?", common.RoleRootUser).First(&user)
|
||||||
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||||
@@ -725,10 +789,10 @@ func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
|
|||||||
return !errors.Is(err, gorm.ErrRecordNotFound)
|
return !errors.Is(err, gorm.ErrRecordNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) FillUserByLinuxDOId() error {
|
func (user *User) FillUserByLinuxDOId() error {
|
||||||
if u.LinuxDOId == "" {
|
if user.LinuxDOId == "" {
|
||||||
return errors.New("linux do id is empty")
|
return errors.New("linux do id is empty")
|
||||||
}
|
}
|
||||||
err := DB.Where("linux_do_id = ?", u.LinuxDOId).First(u).Error
|
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,206 +1,213 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Change UserCache struct to userCache
|
// UserBase struct remains the same as it represents the cached data structure
|
||||||
type userCache struct {
|
type UserBase struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Group string `json:"group"`
|
Group string `json:"group"`
|
||||||
|
Email string `json:"email"`
|
||||||
Quota int `json:"quota"`
|
Quota int `json:"quota"`
|
||||||
Status int `json:"status"`
|
Status int `json:"status"`
|
||||||
Role int `json:"role"`
|
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
Setting string `json:"setting"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rename all exported functions to private ones
|
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||||
// invalidateUserCache clears all user related cache
|
if user.Setting == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return common.StrToMap(user.Setting)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
||||||
|
settingBytes, err := json.Marshal(setting)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to marshal setting: " + err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
user.Setting = string(settingBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserCacheKey returns the key for user cache
|
||||||
|
func getUserCacheKey(userId int) string {
|
||||||
|
return fmt.Sprintf("user:%d", userId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateUserCache clears user cache
|
||||||
func invalidateUserCache(userId int) error {
|
func invalidateUserCache(userId int) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return common.RedisHDelObj(getUserCacheKey(userId))
|
||||||
|
}
|
||||||
|
|
||||||
keys := []string{
|
// updateUserCache updates all user cache fields using hash
|
||||||
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
func updateUserCache(user User) error {
|
||||||
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
if !common.RedisEnabled {
|
||||||
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
return nil
|
||||||
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range keys {
|
return common.RedisHSetObj(
|
||||||
if err := common.RedisDel(key); err != nil {
|
getUserCacheKey(user.Id),
|
||||||
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
|
user.ToBaseUser(),
|
||||||
|
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserCache gets complete user cache from hash
|
||||||
|
func GetUserCache(userId int) (userCache *UserBase, err error) {
|
||||||
|
var user *User
|
||||||
|
var fromDB bool
|
||||||
|
defer func() {
|
||||||
|
// Update Redis cache asynchronously on successful DB read
|
||||||
|
if shouldUpdateRedis(fromDB, err) && user != nil {
|
||||||
|
gopool.Go(func() {
|
||||||
|
if err := updateUserCache(*user); err != nil {
|
||||||
|
common.SysError("failed to update user status cache: " + err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateUserGroupCache updates user group cache
|
// Try getting from Redis first
|
||||||
func updateUserGroupCache(userId int, group string) error {
|
userCache, err = cacheGetUserBase(userId)
|
||||||
if !common.RedisEnabled {
|
if err == nil {
|
||||||
return nil
|
return userCache, nil
|
||||||
}
|
|
||||||
return common.RedisSet(
|
|
||||||
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
|
||||||
group,
|
|
||||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateUserQuotaCache updates user quota cache
|
|
||||||
func updateUserQuotaCache(userId int, quota int) error {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return common.RedisSet(
|
|
||||||
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
|
||||||
fmt.Sprintf("%d", quota),
|
|
||||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateUserStatusCache updates user status cache
|
|
||||||
func updateUserStatusCache(userId int, userEnabled bool) error {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
enabled := "0"
|
|
||||||
if userEnabled {
|
|
||||||
enabled = "1"
|
|
||||||
}
|
|
||||||
return common.RedisSet(
|
|
||||||
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
|
||||||
enabled,
|
|
||||||
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateUserNameCache updates username cache
|
|
||||||
func updateUserNameCache(userId int, username string) error {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return common.RedisSet(
|
|
||||||
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
|
||||||
username,
|
|
||||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateUserCache updates all user cache fields
|
|
||||||
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := updateUserGroupCache(userId, userGroup); err != nil {
|
// If Redis fails, get from DB
|
||||||
return fmt.Errorf("update group cache: %w", err)
|
fromDB = true
|
||||||
}
|
user, err = GetUserById(userId, false)
|
||||||
|
|
||||||
if err := updateUserQuotaCache(userId, quota); err != nil {
|
|
||||||
return fmt.Errorf("update quota cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
|
|
||||||
return fmt.Errorf("update status cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := updateUserNameCache(userId, username); err != nil {
|
|
||||||
return fmt.Errorf("update username cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getUserGroupCache gets user group from cache
|
|
||||||
func getUserGroupCache(userId int) (string, error) {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
|
|
||||||
}
|
|
||||||
|
|
||||||
// getUserQuotaCache gets user quota from cache
|
|
||||||
func getUserQuotaCache(userId int) (int, error) {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, err // Return nil and error if DB lookup fails
|
||||||
}
|
}
|
||||||
return strconv.Atoi(quotaStr)
|
|
||||||
|
// Create cache object from user data
|
||||||
|
userCache = &UserBase{
|
||||||
|
Id: user.Id,
|
||||||
|
Group: user.Group,
|
||||||
|
Quota: user.Quota,
|
||||||
|
Status: user.Status,
|
||||||
|
Username: user.Username,
|
||||||
|
Setting: user.Setting,
|
||||||
|
Email: user.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
return userCache, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserStatusCache gets user status from cache
|
func cacheGetUserBase(userId int) (*UserBase, error) {
|
||||||
func getUserStatusCache(userId int) (int, error) {
|
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return 0, nil
|
return nil, fmt.Errorf("redis is not enabled")
|
||||||
}
|
}
|
||||||
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
|
var userCache UserBase
|
||||||
|
// Try getting from Redis first
|
||||||
|
err := common.RedisHGetObj(getUserCacheKey(userId), &userCache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return strconv.Atoi(statusStr)
|
return &userCache, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserNameCache gets username from cache
|
// Add atomic quota operations using hash fields
|
||||||
func getUserNameCache(userId int) (string, error) {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
|
|
||||||
}
|
|
||||||
|
|
||||||
// getUserCache gets complete user cache
|
|
||||||
func getUserCache(userId int) (*userCache, error) {
|
|
||||||
if !common.RedisEnabled {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
group, err := getUserGroupCache(userId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get group cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
quota, err := getUserQuotaCache(userId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get quota cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
status, err := getUserStatusCache(userId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get status cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
username, err := getUserNameCache(userId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get username cache: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &userCache{
|
|
||||||
Id: userId,
|
|
||||||
Group: group,
|
|
||||||
Quota: quota,
|
|
||||||
Status: status,
|
|
||||||
Username: username,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add atomic quota operations
|
|
||||||
func cacheIncrUserQuota(userId int, delta int64) error {
|
func cacheIncrUserQuota(userId int, delta int64) error {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
|
return common.RedisHIncrBy(getUserCacheKey(userId), "Quota", delta)
|
||||||
return common.RedisIncr(key, delta)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func cacheDecrUserQuota(userId int, delta int64) error {
|
func cacheDecrUserQuota(userId int, delta int64) error {
|
||||||
return cacheIncrUserQuota(userId, -delta)
|
return cacheIncrUserQuota(userId, -delta)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper functions to get individual fields if needed
|
||||||
|
func getUserGroupCache(userId int) (string, error) {
|
||||||
|
cache, err := GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cache.Group, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserQuotaCache(userId int) (int, error) {
|
||||||
|
cache, err := GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return cache.Quota, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserStatusCache(userId int) (int, error) {
|
||||||
|
cache, err := GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return cache.Status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserNameCache(userId int) (string, error) {
|
||||||
|
cache, err := GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return cache.Username, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserSettingCache(userId int) (map[string]interface{}, error) {
|
||||||
|
setting := make(map[string]interface{})
|
||||||
|
cache, err := GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
return setting, err
|
||||||
|
}
|
||||||
|
return cache.GetSetting(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New functions for individual field updates
|
||||||
|
func updateUserStatusCache(userId int, status bool) error {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
statusInt := common.UserStatusEnabled
|
||||||
|
if !status {
|
||||||
|
statusInt = common.UserStatusDisabled
|
||||||
|
}
|
||||||
|
return common.RedisHSetField(getUserCacheKey(userId), "Status", fmt.Sprintf("%d", statusInt))
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserQuotaCache(userId int, quota int) error {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return common.RedisHSetField(getUserCacheKey(userId), "Quota", fmt.Sprintf("%d", quota))
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserGroupCache(userId int, group string) error {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return common.RedisHSetField(getUserCacheKey(userId), "Group", group)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserNameCache(userId int, username string) error {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return common.RedisHSetField(getUserCacheKey(userId), "Username", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserSettingCache(userId int, setting string) error {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
|||||||
}
|
}
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||||
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -500,7 +500,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||||
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
if userQuota-preConsumedQuota < 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
|
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %s, need quota: %d", common.FormatQuota(userQuota), preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
if userQuota > 100*preConsumedQuota {
|
if userQuota > 100*preConsumedQuota {
|
||||||
// 用户额度充足,判断令牌额度是否充足
|
// 用户额度充足,判断令牌额度是否充足
|
||||||
@@ -282,18 +282,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if tokenQuota > 100*preConsumedQuota {
|
if tokenQuota > 100*preConsumedQuota {
|
||||||
// 令牌额度充足,信任令牌
|
// 令牌额度充足,信任令牌
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
|
common.LogInfo(c, fmt.Sprintf("user %d quota %s and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
// because the user has enough quota
|
// because the user has enough quota
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %s, trusted and no need to pre-consume", relayInfo.UserId, common.FormatQuota(userQuota)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
err = service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
@@ -310,7 +310,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
|
|||||||
go func() {
|
go func() {
|
||||||
relayInfoCopy := *relayInfo
|
relayInfoCopy := *relayInfo
|
||||||
|
|
||||||
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
|
err := service.PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -368,7 +368,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
|||||||
//}
|
//}
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
if quotaDelta != 0 {
|
if quotaDelta != 0 {
|
||||||
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|||||||
// release quota
|
// release quota
|
||||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||||
|
|
||||||
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
|
err := service.PostConsumeQuota(relayInfo.ToRelayInfo(), quota, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
selfRoute.POST("/pay", controller.RequestEpay)
|
selfRoute.POST("/pay", controller.RequestEpay)
|
||||||
selfRoute.POST("/amount", controller.RequestAmount)
|
selfRoute.POST("/amount", controller.RequestAmount)
|
||||||
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
|
||||||
|
selfRoute.PUT("/setting", controller.UpdateUserSetting)
|
||||||
}
|
}
|
||||||
|
|
||||||
adminRoute := userRoute.Group("/")
|
adminRoute := userRoute.Group("/")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -9,19 +10,46 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WorkerRequest Worker请求的数据结构
|
||||||
|
type WorkerRequest struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Key string `json:"key"`
|
||||||
|
Method string `json:"method,omitempty"`
|
||||||
|
Headers map[string]string `json:"headers,omitempty"`
|
||||||
|
Body json.RawMessage `json:"body,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoWorkerRequest 通过Worker发送请求
|
||||||
|
func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
||||||
|
if !setting.EnableWorker() {
|
||||||
|
return nil, fmt.Errorf("worker not enabled")
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(req.URL, "https") {
|
||||||
|
return nil, fmt.Errorf("only support https url")
|
||||||
|
}
|
||||||
|
|
||||||
|
workerUrl := setting.WorkerUrl
|
||||||
|
if !strings.HasSuffix(workerUrl, "/") {
|
||||||
|
workerUrl += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化worker请求数据
|
||||||
|
workerPayload, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal worker payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload))
|
||||||
|
}
|
||||||
|
|
||||||
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
||||||
if setting.EnableWorker() {
|
if setting.EnableWorker() {
|
||||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
||||||
if !strings.HasPrefix(originUrl, "https") {
|
req := &WorkerRequest{
|
||||||
return nil, fmt.Errorf("only support https url")
|
URL: originUrl,
|
||||||
|
Key: setting.WorkerValidKey,
|
||||||
}
|
}
|
||||||
workerUrl := setting.WorkerUrl
|
return DoWorkerRequest(req)
|
||||||
if !strings.HasSuffix(workerUrl, "/") {
|
|
||||||
workerUrl += "/"
|
|
||||||
}
|
|
||||||
// post request to worker
|
|
||||||
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
|
|
||||||
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
|
||||||
} else {
|
} else {
|
||||||
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
||||||
return http.Get(originUrl)
|
return http.Get(originUrl)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
relaymodel "one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/setting"
|
"one-api/setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -15,17 +15,17 @@ func DisableChannel(channelId int, channelName string, reason string) {
|
|||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled, reason)
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
notifyRootUser(subject, content)
|
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
func EnableChannel(channelId int, channelName string) {
|
func EnableChannel(channelId int, channelName string) {
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled, "")
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||||
notifyRootUser(subject, content)
|
NotifyRootUser(subject, content, dto.NotifyTypeChannelUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatusCode) bool {
|
func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) bool {
|
||||||
if !common.AutomaticDisableChannelEnabled {
|
if !common.AutomaticDisableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -75,7 +75,7 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool {
|
func ShouldEnableChannel(err error, openaiWithStatusErr *dto.OpenAIErrorWithStatusCode, status int) bool {
|
||||||
if !common.AutomaticEnableChannelEnabled {
|
if !common.AutomaticEnableChannelEnabled {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
116
service/notify-limit.go
Normal file
116
service/notify-limit.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// notifyLimitStore is used for in-memory rate limiting when Redis is disabled
|
||||||
|
var (
|
||||||
|
notifyLimitStore sync.Map
|
||||||
|
cleanupOnce sync.Once
|
||||||
|
)
|
||||||
|
|
||||||
|
type limitCount struct {
|
||||||
|
Count int
|
||||||
|
Timestamp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDuration() time.Duration {
|
||||||
|
minute := constant.NotificationLimitDurationMinute
|
||||||
|
return time.Duration(minute) * time.Minute
|
||||||
|
}
|
||||||
|
|
||||||
|
// startCleanupTask starts a background task to clean up expired entries
|
||||||
|
func startCleanupTask() {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
time.Sleep(time.Hour)
|
||||||
|
now := time.Now()
|
||||||
|
notifyLimitStore.Range(func(key, value interface{}) bool {
|
||||||
|
if limit, ok := value.(limitCount); ok {
|
||||||
|
if now.Sub(limit.Timestamp) >= getDuration() {
|
||||||
|
notifyLimitStore.Delete(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckNotificationLimit checks if the user has exceeded their notification limit
|
||||||
|
// Returns true if the user can send notification, false if limit exceeded
|
||||||
|
func CheckNotificationLimit(userId int, notifyType string) (bool, error) {
|
||||||
|
if common.RedisEnabled {
|
||||||
|
return checkRedisLimit(userId, notifyType)
|
||||||
|
}
|
||||||
|
return checkMemoryLimit(userId, notifyType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkRedisLimit(userId int, notifyType string) (bool, error) {
|
||||||
|
key := fmt.Sprintf("notify_limit:%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
|
||||||
|
|
||||||
|
// Get current count
|
||||||
|
count, err := common.RedisGet(key)
|
||||||
|
if err != nil && err.Error() != "redis: nil" {
|
||||||
|
return false, fmt.Errorf("failed to get notification count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If key doesn't exist, initialize it
|
||||||
|
if count == "" {
|
||||||
|
err = common.RedisSet(key, "1", getDuration())
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
currentCount, _ := strconv.Atoi(count)
|
||||||
|
limit := constant.NotifyLimitCount
|
||||||
|
|
||||||
|
// Check if limit is already reached
|
||||||
|
if currentCount >= limit {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only increment if under limit
|
||||||
|
err = common.RedisIncr(key, 1)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to increment notification count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkMemoryLimit(userId int, notifyType string) (bool, error) {
|
||||||
|
// Ensure cleanup task is started
|
||||||
|
cleanupOnce.Do(startCleanupTask)
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%d:%s:%s", userId, notifyType, time.Now().Format("2006010215"))
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Get current limit count or initialize new one
|
||||||
|
var currentLimit limitCount
|
||||||
|
if value, ok := notifyLimitStore.Load(key); ok {
|
||||||
|
currentLimit = value.(limitCount)
|
||||||
|
// Check if the entry has expired
|
||||||
|
if now.Sub(currentLimit.Timestamp) >= getDuration() {
|
||||||
|
currentLimit = limitCount{Count: 0, Timestamp: now}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
currentLimit = limitCount{Count: 0, Timestamp: now}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment count
|
||||||
|
currentLimit.Count++
|
||||||
|
|
||||||
|
// Check against limits
|
||||||
|
limit := constant.NotifyLimitCount
|
||||||
|
|
||||||
|
// Store updated count
|
||||||
|
notifyLimitStore.Store(key, currentLimit)
|
||||||
|
|
||||||
|
return currentLimit.Count <= limit, nil
|
||||||
|
}
|
||||||
@@ -3,8 +3,10 @@ package service
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
"math"
|
"math"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
constant2 "one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -99,7 +101,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
|||||||
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
|
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
|
||||||
}
|
}
|
||||||
|
|
||||||
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
|
err = PostConsumeQuota(relayInfo, quota, 0, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -222,7 +224,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
} else {
|
} else {
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
if quotaDelta != 0 {
|
if quotaDelta != 0 {
|
||||||
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
}
|
}
|
||||||
@@ -239,3 +241,88 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
|||||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||||
|
if quota < 0 {
|
||||||
|
return errors.New("quota 不能为负数!")
|
||||||
|
}
|
||||||
|
if relayInfo.IsPlayground {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
//if relayInfo.TokenUnlimited {
|
||||||
|
// return nil
|
||||||
|
//}
|
||||||
|
token, err := model.GetTokenByKey(relayInfo.TokenKey, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
||||||
|
return errors.New("令牌额度不足")
|
||||||
|
}
|
||||||
|
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||||
|
|
||||||
|
if quota > 0 {
|
||||||
|
err = model.DecreaseUserQuota(relayInfo.UserId, quota)
|
||||||
|
} else {
|
||||||
|
err = model.IncreaseUserQuota(relayInfo.UserId, -quota)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !relayInfo.IsPlayground {
|
||||||
|
if quota > 0 {
|
||||||
|
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||||
|
} else {
|
||||||
|
err = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sendEmail {
|
||||||
|
if (quota + preConsumedQuota) != 0 {
|
||||||
|
checkAndSendQuotaNotify(relayInfo.UserId, quota, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkAndSendQuotaNotify(userId int, quota int, preConsumedQuota int) {
|
||||||
|
gopool.Go(func() {
|
||||||
|
userCache, err := model.GetUserCache(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to get user cache: " + err.Error())
|
||||||
|
}
|
||||||
|
userSetting := userCache.GetSetting()
|
||||||
|
threshold := common.QuotaRemindThreshold
|
||||||
|
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
|
||||||
|
threshold = int(userCustomThreshold.(float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
//noMoreQuota := userCache.Quota-(quota+preConsumedQuota) <= 0
|
||||||
|
quotaTooLow := false
|
||||||
|
consumeQuota := quota + preConsumedQuota
|
||||||
|
if userCache.Quota-consumeQuota < threshold {
|
||||||
|
quotaTooLow = true
|
||||||
|
}
|
||||||
|
if quotaTooLow {
|
||||||
|
prompt := "您的额度即将用尽"
|
||||||
|
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||||
|
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
||||||
|
err = NotifyUser(userCache, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(userCache.Quota), topUpLink, topUpLink}))
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", userId, err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,15 +3,75 @@ package service
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func notifyRootUser(subject string, content string) {
|
func NotifyRootUser(t string, subject string, content string) {
|
||||||
if common.RootUserEmail == "" {
|
user := model.GetRootUser().ToBaseUser()
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
_ = NotifyUser(user, dto.NewNotify(t, subject, content, nil))
|
||||||
}
|
}
|
||||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
|
||||||
if err != nil {
|
func NotifyUser(user *model.UserBase, data dto.Notify) error {
|
||||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
userSetting := user.GetSetting()
|
||||||
}
|
notifyType, ok := userSetting[constant.UserSettingNotifyType]
|
||||||
|
if !ok {
|
||||||
|
notifyType = constant.NotifyTypeEmail
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check notification limit
|
||||||
|
canSend, err := CheckNotificationLimit(user.Id, data.Type)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !canSend {
|
||||||
|
return fmt.Errorf("notification limit exceeded for user %d with type %s", user.Id, notifyType)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch notifyType {
|
||||||
|
case constant.NotifyTypeEmail:
|
||||||
|
userEmail := user.Email
|
||||||
|
// check setting email
|
||||||
|
if settingEmail, ok := userSetting[constant.UserSettingNotificationEmail]; ok {
|
||||||
|
userEmail = settingEmail.(string)
|
||||||
|
}
|
||||||
|
if userEmail == "" {
|
||||||
|
common.SysLog(fmt.Sprintf("user %d has no email, skip sending email", user.Id))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sendEmailNotify(userEmail, data)
|
||||||
|
case constant.NotifyTypeWebhook:
|
||||||
|
webhookURL, ok := userSetting[constant.UserSettingWebhookUrl]
|
||||||
|
if !ok {
|
||||||
|
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", user.Id))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
webhookURLStr, ok := webhookURL.(string)
|
||||||
|
if !ok {
|
||||||
|
common.SysError(fmt.Sprintf("user %d webhook url is not string type", user.Id))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 webhook secret
|
||||||
|
var webhookSecret string
|
||||||
|
if secret, ok := userSetting[constant.UserSettingWebhookSecret]; ok {
|
||||||
|
webhookSecret, _ = secret.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendEmailNotify(userEmail string, data dto.Notify) error {
|
||||||
|
// make email content
|
||||||
|
content := data.Content
|
||||||
|
// 处理占位符
|
||||||
|
for _, value := range data.Values {
|
||||||
|
content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
|
||||||
|
}
|
||||||
|
return common.SendEmail(data.Title, userEmail, content)
|
||||||
}
|
}
|
||||||
|
|||||||
118
service/webhook.go
Normal file
118
service/webhook.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/setting"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebhookPayload webhook 通知的负载数据
|
||||||
|
type WebhookPayload struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Values []interface{} `json:"values,omitempty"`
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSignature 生成 webhook 签名
|
||||||
|
func generateSignature(secret string, payload []byte) string {
|
||||||
|
h := hmac.New(sha256.New, []byte(secret))
|
||||||
|
h.Write(payload)
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendWebhookNotify 发送 webhook 通知
|
||||||
|
func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error {
|
||||||
|
// 处理占位符
|
||||||
|
content := data.Content
|
||||||
|
for _, value := range data.Values {
|
||||||
|
content = fmt.Sprintf(content, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 webhook 负载
|
||||||
|
payload := WebhookPayload{
|
||||||
|
Type: data.Type,
|
||||||
|
Title: data.Title,
|
||||||
|
Content: content,
|
||||||
|
Values: data.Values,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 序列化负载
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal webhook payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 HTTP 请求
|
||||||
|
var req *http.Request
|
||||||
|
var resp *http.Response
|
||||||
|
|
||||||
|
if setting.EnableWorker() {
|
||||||
|
// 构建worker请求数据
|
||||||
|
workerReq := &WorkerRequest{
|
||||||
|
URL: webhookURL,
|
||||||
|
Key: setting.WorkerValidKey,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
Body: payloadBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有secret,添加签名到headers
|
||||||
|
if secret != "" {
|
||||||
|
signature := generateSignature(secret, payloadBytes)
|
||||||
|
workerReq.Headers["X-Webhook-Signature"] = signature
|
||||||
|
workerReq.Headers["Authorization"] = "Bearer " + secret
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err = DoWorkerRequest(workerReq)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send webhook request through worker: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
req, err = http.NewRequest(http.MethodPost, webhookURL, bytes.NewBuffer(payloadBytes))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create webhook request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置请求头
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// 如果有 secret,生成签名
|
||||||
|
if secret != "" {
|
||||||
|
signature := generateSignature(secret, payloadBytes)
|
||||||
|
req.Header.Set("X-Webhook-Signature", signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
client := GetImpatientHttpClient()
|
||||||
|
resp, err = client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send webhook request: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return fmt.Errorf("webhook request failed with status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -26,6 +26,10 @@ import {
|
|||||||
Tag,
|
Tag,
|
||||||
Typography,
|
Typography,
|
||||||
Collapsible,
|
Collapsible,
|
||||||
|
Select,
|
||||||
|
Radio,
|
||||||
|
RadioGroup,
|
||||||
|
AutoComplete,
|
||||||
} from '@douyinfe/semi-ui';
|
} from '@douyinfe/semi-ui';
|
||||||
import {
|
import {
|
||||||
getQuotaPerUnit,
|
getQuotaPerUnit,
|
||||||
@@ -67,14 +71,16 @@ const PersonalSetting = () => {
|
|||||||
const [transferAmount, setTransferAmount] = useState(0);
|
const [transferAmount, setTransferAmount] = useState(0);
|
||||||
const [isModelsExpanded, setIsModelsExpanded] = useState(false);
|
const [isModelsExpanded, setIsModelsExpanded] = useState(false);
|
||||||
const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量
|
const MODELS_DISPLAY_COUNT = 10; // 默认显示的模型数量
|
||||||
|
const [notificationSettings, setNotificationSettings] = useState({
|
||||||
|
warningType: 'email',
|
||||||
|
warningThreshold: 100000,
|
||||||
|
webhookUrl: '',
|
||||||
|
webhookSecret: '',
|
||||||
|
notificationEmail: ''
|
||||||
|
});
|
||||||
|
const [showWebhookDocs, setShowWebhookDocs] = useState(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// let user = localStorage.getItem('user');
|
|
||||||
// if (user) {
|
|
||||||
// userDispatch({ type: 'login', payload: user });
|
|
||||||
// }
|
|
||||||
// console.log(localStorage.getItem('user'))
|
|
||||||
|
|
||||||
let status = localStorage.getItem('status');
|
let status = localStorage.getItem('status');
|
||||||
if (status) {
|
if (status) {
|
||||||
status = JSON.parse(status);
|
status = JSON.parse(status);
|
||||||
@@ -105,6 +111,19 @@ const PersonalSetting = () => {
|
|||||||
return () => clearInterval(countdownInterval); // Clean up on unmount
|
return () => clearInterval(countdownInterval); // Clean up on unmount
|
||||||
}, [disableButton, countdown]);
|
}, [disableButton, countdown]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (userState?.user?.setting) {
|
||||||
|
const settings = JSON.parse(userState.user.setting);
|
||||||
|
setNotificationSettings({
|
||||||
|
warningType: settings.notify_type || 'email',
|
||||||
|
warningThreshold: settings.quota_warning_threshold || 500000,
|
||||||
|
webhookUrl: settings.webhook_url || '',
|
||||||
|
webhookSecret: settings.webhook_secret || '',
|
||||||
|
notificationEmail: settings.notification_email || ''
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [userState?.user?.setting]);
|
||||||
|
|
||||||
const handleInputChange = (name, value) => {
|
const handleInputChange = (name, value) => {
|
||||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||||
};
|
};
|
||||||
@@ -300,7 +319,36 @@ const PersonalSetting = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleNotificationSettingChange = (type, value) => {
|
||||||
|
setNotificationSettings(prev => ({
|
||||||
|
...prev,
|
||||||
|
[type]: value.target ? value.target.value : value // 处理 Radio 事件对象
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
const saveNotificationSettings = async () => {
|
||||||
|
try {
|
||||||
|
const res = await API.put('/api/user/setting', {
|
||||||
|
notify_type: notificationSettings.warningType,
|
||||||
|
quota_warning_threshold: notificationSettings.warningThreshold,
|
||||||
|
webhook_url: notificationSettings.webhookUrl,
|
||||||
|
webhook_secret: notificationSettings.webhookSecret,
|
||||||
|
notification_email: notificationSettings.notificationEmail
|
||||||
|
});
|
||||||
|
|
||||||
|
if (res.data.success) {
|
||||||
|
showSuccess(t('通知设置已更新'));
|
||||||
|
await getUserData();
|
||||||
|
} else {
|
||||||
|
showError(res.data.message);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showError(t('更新通知设置失败'));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<Layout>
|
<Layout>
|
||||||
<Layout.Content>
|
<Layout.Content>
|
||||||
@@ -526,9 +574,7 @@ const PersonalSetting = () => {
|
|||||||
</div>
|
</div>
|
||||||
<div style={{marginTop: 10}}>
|
<div style={{marginTop: 10}}>
|
||||||
<Typography.Text strong>{t('微信')}</Typography.Text>
|
<Typography.Text strong>{t('微信')}</Typography.Text>
|
||||||
<div
|
<div style={{display: 'flex', justifyContent: 'space-between'}}>
|
||||||
style={{display: 'flex', justifyContent: 'space-between'}}
|
|
||||||
>
|
|
||||||
<div>
|
<div>
|
||||||
<Input
|
<Input
|
||||||
value={
|
value={
|
||||||
@@ -541,12 +587,16 @@ const PersonalSetting = () => {
|
|||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<Button
|
<Button
|
||||||
disabled={
|
disabled={!status.wechat_login}
|
||||||
(userState.user && userState.user.wechat_id !== '') ||
|
onClick={() => {
|
||||||
!status.wechat_login
|
setShowWeChatBindModal(true);
|
||||||
}
|
}}
|
||||||
>
|
>
|
||||||
{status.wechat_login ? t('绑定') : t('未启用')}
|
{userState.user && userState.user.wechat_id !== ''
|
||||||
|
? t('修改绑定')
|
||||||
|
: status.wechat_login
|
||||||
|
? t('绑定')
|
||||||
|
: t('未启用')}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -672,18 +722,8 @@ const PersonalSetting = () => {
|
|||||||
style={{marginTop: '10px'}}
|
style={{marginTop: '10px'}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{status.wechat_login && (
|
|
||||||
<Button
|
|
||||||
onClick={() => {
|
|
||||||
setShowWeChatBindModal(true);
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{t('绑定微信账号')}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
<Modal
|
<Modal
|
||||||
onCancel={() => setShowWeChatBindModal(false)}
|
onCancel={() => setShowWeChatBindModal(false)}
|
||||||
// onOpen={() => setShowWeChatBindModal(true)}
|
|
||||||
visible={showWeChatBindModal}
|
visible={showWeChatBindModal}
|
||||||
size={'small'}
|
size={'small'}
|
||||||
>
|
>
|
||||||
@@ -707,9 +747,121 @@ const PersonalSetting = () => {
|
|||||||
</Modal>
|
</Modal>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
|
<Card style={{marginTop: 10}}>
|
||||||
|
<Typography.Title heading={6}>{t('通知设置')}</Typography.Title>
|
||||||
|
<div style={{marginTop: 20}}>
|
||||||
|
<Typography.Text strong>{t('通知方式')}</Typography.Text>
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
|
<RadioGroup
|
||||||
|
value={notificationSettings.warningType}
|
||||||
|
onChange={value => handleNotificationSettingChange('warningType', value)}
|
||||||
|
>
|
||||||
|
<Radio value="email">{t('邮件通知')}</Radio>
|
||||||
|
<Radio value="webhook">{t('Webhook通知')}</Radio>
|
||||||
|
</RadioGroup>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{notificationSettings.warningType === 'webhook' && (
|
||||||
|
<>
|
||||||
|
<div style={{marginTop: 20}}>
|
||||||
|
<Typography.Text strong>{t('Webhook地址')}</Typography.Text>
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
|
<Input
|
||||||
|
value={notificationSettings.webhookUrl}
|
||||||
|
onChange={val => handleNotificationSettingChange('webhookUrl', val)}
|
||||||
|
placeholder={t('请输入Webhook地址,例如: https://example.com/webhook')}
|
||||||
|
/>
|
||||||
|
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
|
||||||
|
{t('只支持https,系统将以 POST 方式发送通知,请确保地址可以接收 POST 请求')}
|
||||||
|
</Typography.Text>
|
||||||
|
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
|
||||||
|
<div style={{cursor: 'pointer'}} onClick={() => setShowWebhookDocs(!showWebhookDocs)}>
|
||||||
|
{t('Webhook请求结构')} {showWebhookDocs ? '▼' : '▶'}
|
||||||
|
</div>
|
||||||
|
<Collapsible isOpen={showWebhookDocs}>
|
||||||
|
<pre style={{marginTop: 4, background: 'var(--semi-color-fill-0)', padding: 8, borderRadius: 4}}>
|
||||||
|
{`{
|
||||||
|
"type": "quota_exceed", // 通知类型
|
||||||
|
"title": "标题", // 通知标题
|
||||||
|
"content": "通知内容", // 通知内容,支持 {{value}} 变量占位符
|
||||||
|
"values": ["值1", "值2"], // 按顺序替换content中的 {{value}} 占位符
|
||||||
|
"timestamp": 1739950503 // 时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
示例:
|
||||||
|
{
|
||||||
|
"type": "quota_exceed",
|
||||||
|
"title": "额度预警通知",
|
||||||
|
"content": "您的额度即将用尽,当前剩余额度为 {{value}}",
|
||||||
|
"values": ["$0.99"],
|
||||||
|
"timestamp": 1739950503
|
||||||
|
}`}
|
||||||
|
</pre>
|
||||||
|
</Collapsible>
|
||||||
|
</Typography.Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div style={{marginTop: 20}}>
|
||||||
|
<Typography.Text strong>{t('接口凭证(可选)')}</Typography.Text>
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
|
<Input
|
||||||
|
value={notificationSettings.webhookSecret}
|
||||||
|
onChange={val => handleNotificationSettingChange('webhookSecret', val)}
|
||||||
|
placeholder={t('请输入密钥')}
|
||||||
|
/>
|
||||||
|
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
|
||||||
|
{t('密钥将以 Bearer 方式添加到请求头中,用于验证webhook请求的合法性')}
|
||||||
|
</Typography.Text>
|
||||||
|
<Typography.Text type="secondary" style={{marginTop: 4, display: 'block'}}>
|
||||||
|
{t('Authorization: Bearer your-secret-key')}
|
||||||
|
</Typography.Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{notificationSettings.warningType === 'email' && (
|
||||||
|
<div style={{marginTop: 20}}>
|
||||||
|
<Typography.Text strong>{t('通知邮箱')}</Typography.Text>
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
|
<Input
|
||||||
|
value={notificationSettings.notificationEmail}
|
||||||
|
onChange={val => handleNotificationSettingChange('notificationEmail', val)}
|
||||||
|
placeholder={t('留空则使用账号绑定的邮箱')}
|
||||||
|
/>
|
||||||
|
<Typography.Text type="secondary" style={{marginTop: 8, display: 'block'}}>
|
||||||
|
{t('设置用于接收额度预警的邮箱地址,不填则使用账号绑定的邮箱')}
|
||||||
|
</Typography.Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<div style={{marginTop: 20}}>
|
||||||
|
<Typography.Text strong>{t('额度预警阈值')} {renderQuotaWithPrompt(notificationSettings.warningThreshold)}</Typography.Text>
|
||||||
|
<div style={{marginTop: 10}}>
|
||||||
|
<AutoComplete
|
||||||
|
value={notificationSettings.warningThreshold}
|
||||||
|
onChange={val => handleNotificationSettingChange('warningThreshold', val)}
|
||||||
|
style={{width: 200}}
|
||||||
|
placeholder={t('请输入预警额度')}
|
||||||
|
data={[
|
||||||
|
{ value: 100000, label: '0.2$' },
|
||||||
|
{ value: 500000, label: '1$' },
|
||||||
|
{ value: 1000000, label: '5$' },
|
||||||
|
{ value: 5000000, label: '10$' }
|
||||||
|
]}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<Typography.Text type="secondary" style={{marginTop: 10, display: 'block'}}>
|
||||||
|
{t('当剩余额度低于此数值时,系统将通过选择的方式发送通知')}
|
||||||
|
</Typography.Text>
|
||||||
|
</div>
|
||||||
|
<div style={{marginTop: 20}}>
|
||||||
|
<Button type="primary" onClick={saveNotificationSettings}>
|
||||||
|
{t('保存设置')}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
<Modal
|
<Modal
|
||||||
onCancel={() => setShowEmailBindModal(false)}
|
onCancel={() => setShowEmailBindModal(false)}
|
||||||
// onOpen={() => setShowEmailBindModal(true)}
|
|
||||||
onOk={bindEmail}
|
onOk={bindEmail}
|
||||||
visible={showEmailBindModal}
|
visible={showEmailBindModal}
|
||||||
size={'small'}
|
size={'small'}
|
||||||
|
|||||||
@@ -368,6 +368,17 @@ const SystemSetting = () => {
|
|||||||
</a>
|
</a>
|
||||||
)
|
)
|
||||||
</Header>
|
</Header>
|
||||||
|
<Message info>
|
||||||
|
注意:代理功能仅对图片请求和 Webhook 请求生效,不会影响其他 API 请求。如需配置 API 请求代理,请参考
|
||||||
|
<a
|
||||||
|
href='https://github.com/Calcium-Ion/new-api/blob/main/docs/channel/other_setting.md'
|
||||||
|
target='_blank'
|
||||||
|
rel='noreferrer'
|
||||||
|
>
|
||||||
|
{' '}API 代理设置文档
|
||||||
|
</a>
|
||||||
|
。
|
||||||
|
</Message>
|
||||||
<Form.Group widths='equal'>
|
<Form.Group widths='equal'>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='Worker地址,不填写则不启用代理'
|
label='Worker地址,不填写则不启用代理'
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ export function renderQuotaWithPrompt(quota, digits) {
|
|||||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||||
displayInCurrency = displayInCurrency === 'true';
|
displayInCurrency = displayInCurrency === 'true';
|
||||||
if (displayInCurrency) {
|
if (displayInCurrency) {
|
||||||
return '|' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
|
return ' | ' + i18next.t('等价金额') + ': ' + renderQuota(quota, digits) + '';
|
||||||
}
|
}
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user