chore: 更新依赖、配置和代码生成

主要更新:
- 更新 go.mod/go.sum 依赖
- 重新生成 Ent ORM 代码
- 更新 Wire 依赖注入配置
- 添加 docker-compose.override.yml 到 .gitignore
- 更新 README 文档(Simple Mode 说明和已知问题)
- 清理调试日志
- 其他代码优化和格式修复
This commit is contained in:
ianshaw
2026-01-03 06:37:08 -08:00
parent b1702de522
commit 112a2d0866
121 changed files with 3058 additions and 2948 deletions

View File

@@ -29,6 +29,9 @@ type Account struct {
RateLimitResetAt *time.Time
OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
@@ -39,6 +42,13 @@ type Account struct {
Groups []*Group
}
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
@@ -54,6 +64,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true
}
@@ -163,6 +176,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
@@ -206,7 +327,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
@@ -229,7 +350,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -300,15 +421,15 @@ func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
func (a *Account) IsOpenAIAPIKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeAPIKey {
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
@@ -338,8 +459,8 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token")
}
func (a *Account) GetOpenAIAPIKey() string {
if !a.IsOpenAIAPIKey() {
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")

View File

@@ -1,5 +1,3 @@
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
package service
import (
@@ -51,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error

View File

@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
panic("unexpected SetOverloaded call")
}
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}

View File

@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" {
// API Key - use Platform API
authToken = account.GetOpenAIAPIKey()
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case AccountTypeAPIKey:
case AccountTypeApiKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)

View File

@@ -2,7 +2,7 @@ package service
import "time"
type APIKey struct {
type ApiKey struct {
ID int64
UserID int64
Key string
@@ -15,6 +15,6 @@ type APIKey struct {
Group *Group
}
func (k *APIKey) IsActive() bool {
func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -14,39 +14,39 @@ import (
)
var (
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type APIKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cache ApiKeyCache
cfg *config.Config
}
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache APIKeyCache,
cache ApiKeyCache,
cfg *config.Config,
) *APIKeyService {
return &APIKeyService{
) *ApiKeyService {
return &ApiKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewAPIKeyService(
}
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *APIKeyService) GenerateKey() (string, error) {
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.APIKeyPrefix
prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" {
prefix = "sk-"
}
@@ -117,10 +117,10 @@ func (s *APIKeyService) GenerateKey() (string, error) {
}
// ValidateCustomKey 验证自定义API Key格式
func (s *APIKeyService) ValidateCustomKey(key string) error {
func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrAPIKeyTooShort
return ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
@@ -131,14 +131,14 @@ func (s *APIKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' {
continue
}
return ErrAPIKeyInvalidChars
return ErrApiKeyInvalidChars
}
return nil
}
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
@@ -150,14 +150,14 @@ func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64)
}
if count >= apiKeyMaxErrorsPerHour {
return ErrAPIKeyRateLimited
return ErrApiKeyRateLimited
}
return nil
}
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
@@ -168,7 +168,7 @@ func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group
}
// Create 创建API Key
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -204,7 +204,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
@@ -219,9 +219,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在,增加错误计数
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
}
key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
// 创建API Key记录
apiKey := &APIKey{
apiKey := &ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
@@ -251,7 +251,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
// List 获取用户的API Key列表
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *APIKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil
}
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -272,7 +272,7 @@ func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
}
// GetByID 根据ID获取API Key
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -301,7 +301,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
}
// Update 更新API Key
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
@@ -379,7 +379,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -406,7 +406,7 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -460,7 +460,7 @@ func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
panic("unexpected GetByID call")
}
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
panic("unexpected Update call")
}
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -150,17 +150,17 @@ func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // 验证缓存未被清除
}
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestAPIKeyService_Delete_Success(t *testing.T) {
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
@@ -168,37 +168,37 @@ func TestAPIKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
}
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)

View File

@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil

View File

@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
} `json:"data"`
}
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeAPIKey
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeAPIKey
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeAPIKey
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

View File

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeApiKey = "apikey" // API Key类型账号
)
// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码加密存储
SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -81,20 +81,27 @@ const (
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
const AdminAPIKeyPrefix = "admin-"
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"

View File

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)
// SMTPConfig SMTP配置
type SMTPConfig struct {
// SmtpConfig SMTP配置
type SmtpConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySMTPHost]
host := settings[SettingKeySmtpHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[SettingKeySMTPUseTLS] == "true"
useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SMTPConfig{
return &SmtpConfig{
Host: host,
Port: port,
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSMTPConfig(ctx)
config, err := s.GetSmtpConfig(ctx)
if err != nil {
return err
}
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {

View File

@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeAPIKey:
apiKey := account.GetOpenAIAPIKey()
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeAPIKey:
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
// Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response
var errType, errMsg string
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
APIKey *APIKey
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,

View File

@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyAPIBaseURL,
SettingKeyApiBaseUrl,
SettingKeyContactInfo,
SettingKeyDocURL,
SettingKeyDocUrl,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
DocUrl: settings[SettingKeyDocUrl],
}, nil
}
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SMTPPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyDocUrl] = settings.DocUrl
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
return s.settingRepo.SetMultiple(ctx, updates)
}
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
DocUrl: settings[SettingKeyDocUrl],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SMTPPort = port
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SMTPPort = 587
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
@@ -245,10 +258,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
// 敏感信息直接返回方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
return result
}
@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil
}
// GetAdminAPIKey 获取完整的管理员 API Key仅供内部验证使用
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
return key, nil
}
// DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
}
// IsModelFallbackEnabled 检查是否启用模型兜底机制
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
if err != nil {
return false // Default: disabled
}
return value == "true"
}
// GetFallbackModel 获取指定平台的兜底模型
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
var key string
var defaultModel string
switch platform {
case PlatformAnthropic:
key = SettingKeyFallbackModelAnthropic
defaultModel = "claude-3-5-sonnet-20241022"
case PlatformOpenAI:
key = SettingKeyFallbackModelOpenAI
defaultModel = "gpt-4o"
case PlatformGemini:
key = SettingKeyFallbackModelGemini
defaultModel = "gemini-2.5-pro"
case PlatformAntigravity:
key = SettingKeyFallbackModelAntigravity
defaultModel = "gemini-2.5-pro"
default:
return ""
}
value, err := s.settingRepo.GetValue(ctx, key)
if err != nil || value == "" {
return defaultModel
}
return value
}

View File

@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
@@ -19,12 +19,19 @@ type SystemSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ApiBaseUrl string
ContactInfo string
DocURL string
DocUrl string
DefaultConcurrency int
DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
type PublicSettings struct {
@@ -35,8 +42,8 @@ type PublicSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ApiBaseUrl string
ContactInfo string
DocURL string
DocUrl string
Version string
}

View File

@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HTMLURL string `json:"html_url"`
HtmlURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HTMLURL string `json:"html_url"`
HtmlUrl string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
BrowserDownloadUrl string `json:"browser_download_url"`
Size int64 `json:"size"`
}
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadURL,
DownloadURL: a.BrowserDownloadUrl,
Size: a.Size,
}
}
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HTMLURL: release.HTMLURL,
HtmlURL: release.HtmlUrl,
Assets: assets,
},
Cached: false,

View File

@@ -0,0 +1,35 @@
package service
import "time"
// clampInt 将整数限制在指定范围内
func clampInt(value, min, max int) int {
if value < min {
return min
}
if value > max {
return max
}
return value
}
// clampFloat64 将浮点数限制在指定范围内
func clampFloat64(value, min, max float64) float64 {
if value < min {
return min
}
if value > max {
return max
}
return value
}
// remainingSecondsUntil 计算到指定时间的剩余秒数,保证非负
func remainingSecondsUntil(t time.Time) int {
seconds := int(time.Until(t).Seconds())
if seconds < 0 {
return 0
}
return seconds
}

View File

@@ -10,7 +10,7 @@ const (
type UsageLog struct {
ID int64
UserID int64
APIKeyID int64
ApiKeyID int64
AccountID int64
RequestID string
Model string
@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time
User *User
APIKey *APIKey
ApiKey *ApiKey
Account *Account
Group *Group
Subscription *UserSubscription

View File

@@ -17,7 +17,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
usageLog := &UsageLog{
UserID: req.UserID,
APIKeyID: req.APIKeyID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil
}
// ListByAPIKey 获取API Key的使用日志列表
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil
}
// GetStatsByAPIKey 获取API Key的使用统计
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err)
}
@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
APIKeys []APIKey
ApiKeys []ApiKey
Subscriptions []UserSubscription
}

View File

@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled: input.Enabled,
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err)
}
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def.Enabled = *input.Enabled
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err)
}
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) {
if err != nil {
return validationError(def.Name + " has an invalid pattern")
}
if !re.MatchString(value) {
msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" {
msg = *v.Message
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
}
return false
}
func validateDefinitionPattern(def *UserAttributeDefinition) error {
if def == nil {
return nil
}
if def.Validation.Pattern == nil {
return nil
}
pattern := strings.TrimSpace(*def.Validation.Pattern)
if pattern == "" {
return nil
}
if _, err := regexp.Compile(pattern); err != nil {
return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", fmt.Sprintf("invalid pattern for %s: %v", def.Name, err))
}
return nil
}

View File

@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService {
return svc
}
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
func ProvideAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
oauthSvc *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
svc.Start()
return svc
}
// ProvideDeferredService creates and starts DeferredService
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
@@ -73,20 +61,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc
}
// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
func ProvideOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
svc := NewOpsMetricsCollector(opsService, concurrencyService)
svc.Start()
return svc
}
// ProvideOpsAlertService creates and starts OpsAlertService.
func ProvideOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
svc := NewOpsAlertService(opsService, userService, emailService)
svc.Start()
return svc
}
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
@@ -101,14 +75,13 @@ var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewAPIKeyService,
NewApiKeyService,
NewGroupService,
NewAccountService,
NewProxyService,
NewRedeemService,
NewUsageService,
NewDashboardService,
NewOpsService,
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
@@ -139,8 +112,7 @@ var ProviderSet = wire.NewSet(
ProvideTokenRefreshService,
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
ProvideOpsMetricsCollector,
ProvideOpsAlertService,
NewAntigravityQuotaFetcher,
NewUserAttributeService,
NewUsageCache,
)