chore: 更新依赖、配置和代码生成

主要更新:
- 更新 go.mod/go.sum 依赖
- 重新生成 Ent ORM 代码
- 更新 Wire 依赖注入配置
- 添加 docker-compose.override.yml 到 .gitignore
- 更新 README 文档(Simple Mode 说明和已知问题)
- 清理调试日志
- 其他代码优化和格式修复
This commit is contained in:
ianshaw
2026-01-03 06:37:08 -08:00
parent b1702de522
commit 112a2d0866
121 changed files with 3058 additions and 2948 deletions

View File

@@ -1,4 +1,3 @@
// Package config provides application configuration management.
package config
import (
@@ -140,7 +139,7 @@ type GatewayConfig struct {
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
// 是否允许对部分 400 错误触发 failover默认关闭以避免改变语义
FailoverOn400 bool `mapstructure:"failover_on_400"`
@@ -242,7 +241,7 @@ type DefaultConfig struct {
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
APIKeyPrefix string `mapstructure:"api_key_prefix"`
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}

View File

@@ -1,4 +1,3 @@
// Package config provides application configuration management.
package config
import "github.com/google/wire"

View File

@@ -1,5 +1,3 @@
// Package admin provides HTTP handlers for administrative operations including
// dashboard statistics, user management, API key management, and account management.
package admin
import (
@@ -77,8 +75,8 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
"active_users": stats.ActiveUsers,
// API Key 统计
"total_api_keys": stats.TotalAPIKeys,
"active_api_keys": stats.ActiveAPIKeys,
"total_api_keys": stats.TotalApiKeys,
"active_api_keys": stats.ActiveApiKeys,
// 账户统计
"total_accounts": stats.TotalAccounts,
@@ -195,10 +193,10 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GetApiKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "5")
@@ -207,7 +205,7 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
limit = 5
}
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
@@ -275,26 +273,26 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
response.Success(c, gin.H{"stats": stats})
}
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
// POST /api/v1/admin/dashboard/api-keys-usage
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
var req BatchAPIKeysUsageRequest
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.APIKeyIDs) == 0 {
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
return
}
outKeys := make([]dto.APIKey, 0, len(keys))
outKeys := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, outKeys, total, page, pageSize)
}

View File

@@ -36,24 +36,29 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
SMTPPassword: settings.SMTPPassword,
SMTPFrom: settings.SMTPFrom,
SMTPFromName: settings.SMTPFromName,
SMTPUseTLS: settings.SMTPUseTLS,
SmtpHost: settings.SmtpHost,
SmtpPort: settings.SmtpPort,
SmtpUsername: settings.SmtpUsername,
SmtpPassword: settings.SmtpPassword,
SmtpFrom: settings.SmtpFrom,
SmtpFromName: settings.SmtpFromName,
SmtpUseTLS: settings.SmtpUseTLS,
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKey: settings.TurnstileSecretKey,
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DocUrl: settings.DocUrl,
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
FallbackModelGemini: settings.FallbackModelGemini,
FallbackModelAntigravity: settings.FallbackModelAntigravity,
})
}
@@ -64,13 +69,13 @@ type UpdateSettingsRequest struct {
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
@@ -81,13 +86,20 @@ type UpdateSettingsRequest struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
DocUrl string `json:"doc_url"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
// UpdateSettings 更新系统设置
@@ -106,8 +118,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
if req.SMTPPort <= 0 {
req.SMTPPort = 587
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// Turnstile 参数验证
@@ -143,24 +155,29 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
SmtpHost: req.SmtpHost,
SmtpPort: req.SmtpPort,
SmtpUsername: req.SmtpUsername,
SmtpPassword: req.SmtpPassword,
SmtpFrom: req.SmtpFrom,
SmtpFromName: req.SmtpFromName,
SmtpUseTLS: req.SmtpUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
DocUrl: req.DocUrl,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@@ -178,67 +195,72 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
SMTPPassword: updatedSettings.SMTPPassword,
SMTPFrom: updatedSettings.SMTPFrom,
SMTPFromName: updatedSettings.SMTPFromName,
SMTPUseTLS: updatedSettings.SMTPUseTLS,
SmtpHost: updatedSettings.SmtpHost,
SmtpPort: updatedSettings.SmtpPort,
SmtpUsername: updatedSettings.SmtpUsername,
SmtpPassword: updatedSettings.SmtpPassword,
SmtpFrom: updatedSettings.SmtpFrom,
SmtpFromName: updatedSettings.SmtpFromName,
SmtpUseTLS: updatedSettings.SmtpUseTLS,
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
SiteName: updatedSettings.SiteName,
SiteLogo: updatedSettings.SiteLogo,
SiteSubtitle: updatedSettings.SiteSubtitle,
APIBaseURL: updatedSettings.APIBaseURL,
ApiBaseUrl: updatedSettings.ApiBaseUrl,
ContactInfo: updatedSettings.ContactInfo,
DocURL: updatedSettings.DocURL,
DocUrl: updatedSettings.DocUrl,
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
FallbackModelGemini: updatedSettings.FallbackModelGemini,
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
})
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPUseTLS bool `json:"smtp_use_tls"`
// TestSmtpRequest 测试SMTP连接请求
type TestSmtpRequest struct {
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpUseTLS bool `json:"smtp_use_tls"`
}
// TestSMTPConnection 测试SMTP连接
// TestSmtpConnection 测试SMTP连接
// POST /api/v1/admin/settings/test-smtp
func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
var req TestSMTPRequest
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
var req TestSmtpRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SMTPPort <= 0 {
req.SMTPPort = 587
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SMTPPassword
password := req.SmtpPassword
if password == "" {
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SMTPConfig{
Host: req.SMTPHost,
Port: req.SMTPPort,
Username: req.SMTPUsername,
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
Password: password,
UseTLS: req.SMTPUseTLS,
UseTLS: req.SmtpUseTLS,
}
err := h.emailService.TestSMTPConnectionWithConfig(config)
err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -250,13 +272,13 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
// SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"`
SMTPHost string `json:"smtp_host" binding:"required"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
}
// SendTestEmail 发送测试邮件
@@ -268,27 +290,27 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
return
}
if req.SMTPPort <= 0 {
req.SMTPPort = 587
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SMTPPassword
password := req.SmtpPassword
if password == "" {
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SMTPConfig{
Host: req.SMTPHost,
Port: req.SMTPPort,
Username: req.SMTPUsername,
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
Password: password,
From: req.SMTPFrom,
FromName: req.SMTPFromName,
UseTLS: req.SMTPUseTLS,
From: req.SmtpFrom,
FromName: req.SmtpFromName,
UseTLS: req.SmtpUseTLS,
}
siteName := h.settingService.GetSiteName(c.Request.Context())
@@ -333,10 +355,10 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
response.Success(c, gin.H{"message": "Test email sent successfully"})
}
// GetAdminAPIKey 获取管理员 API Key 状态
// GetAdminApiKey 获取管理员 API Key 状态
// GET /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminAPIKeyStatus(c.Request.Context())
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
@@ -348,10 +370,10 @@ func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
})
}
// RegenerateAdminAPIKey 生成/重新生成管理员 API Key
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
// POST /api/v1/admin/settings/admin-api-key/regenerate
func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminAPIKey(c.Request.Context())
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
@@ -362,10 +384,10 @@ func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
})
}
// DeleteAdminAPIKey 删除管理员 API Key
// DeleteAdminApiKey 删除管理员 API Key
// DELETE /api/v1/admin/settings/admin-api-key
func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
if err := h.settingService.DeleteAdminAPIKey(c.Request.Context()); err != nil {
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@@ -17,14 +17,14 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.APIKeyService
apiKeyService *service.ApiKeyService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
apiKeyService *service.ApiKeyService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
@@ -125,7 +125,7 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: userID,
APIKeyID: apiKeyID,
ApiKeyID: apiKeyID,
AccountID: accountID,
GroupID: groupID,
Model: model,
@@ -207,7 +207,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -269,9 +269,9 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
response.Success(c, result)
}
// SearchAPIKeys handles searching API keys by user
// SearchApiKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
@@ -285,22 +285,22 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
userID = id
}
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return simplified API key list (only id and name)
type SimpleAPIKey struct {
type SimpleApiKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleAPIKey, len(keys))
result := make([]SimpleApiKey, len(keys))
for i, k := range keys {
result[i] = SimpleAPIKey{
result[i] = SimpleApiKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,

View File

@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
return
}
out := make([]dto.APIKey, 0, len(keys))
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.APIKeyFromService(&keys[i]))
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, total, page, pageSize)
}

View File

@@ -14,11 +14,11 @@ import (
// APIKeyHandler handles API key-related requests
type APIKeyHandler struct {
apiKeyService *service.APIKeyService
apiKeyService *service.ApiKeyService
}
// NewAPIKeyHandler creates a new APIKeyHandler
func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
return &APIKeyHandler{
apiKeyService: apiKeyService,
}
@@ -56,9 +56,9 @@ func (h *APIKeyHandler) List(c *gin.Context) {
return
}
out := make([]dto.APIKey, 0, len(keys))
out := make([]dto.ApiKey, 0, len(keys))
for i := range keys {
out = append(out, *dto.APIKeyFromService(&keys[i]))
out = append(out, *dto.ApiKeyFromService(&keys[i]))
}
response.Paginated(c, out, result.Total, page, pageSize)
}
@@ -90,7 +90,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.APIKeyFromService(key))
response.Success(c, dto.ApiKeyFromService(key))
}
// Create handles creating a new API key
@@ -108,7 +108,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
return
}
svcReq := service.CreateAPIKeyRequest{
svcReq := service.CreateApiKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
@@ -119,7 +119,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
return
}
response.Success(c, dto.APIKeyFromService(key))
response.Success(c, dto.ApiKeyFromService(key))
}
// Update handles updating an API key
@@ -143,7 +143,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq := service.UpdateAPIKeyRequest{}
svcReq := service.UpdateApiKeyRequest{}
if req.Name != "" {
svcReq.Name = &req.Name
}
@@ -158,7 +158,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.APIKeyFromService(key))
response.Success(c, dto.ApiKeyFromService(key))
}
// Delete handles deleting an API key

View File

@@ -5,13 +5,13 @@ type SystemSettings struct {
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
SMTPUsername string `json:"smtp_username"`
SMTPPassword string `json:"smtp_password,omitempty"`
SMTPFrom string `json:"smtp_from_email"`
SMTPFromName string `json:"smtp_from_name"`
SMTPUseTLS bool `json:"smtp_use_tls"`
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password,omitempty"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
@@ -20,12 +20,19 @@ type SystemSettings struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
DocUrl string `json:"doc_url"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
type PublicSettings struct {
@@ -36,8 +43,8 @@ type PublicSettings struct {
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
DocUrl string `json:"doc_url"`
Version string `json:"version"`
}

View File

@@ -7,7 +7,6 @@ import (
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
Dashboard *admin.DashboardHandler
Ops *admin.OpsHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler

View File

@@ -22,7 +22,6 @@ type OpenAIGatewayHandler struct {
gatewayService *service.OpenAIGatewayService
billingCacheService *service.BillingCacheService
concurrencyHelper *ConcurrencyHelper
opsService *service.OpsService
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -30,21 +29,19 @@ func NewOpenAIGatewayHandler(
gatewayService *service.OpenAIGatewayService,
concurrencyService *service.ConcurrencyService,
billingCacheService *service.BillingCacheService,
opsService *service.OpsService,
) *OpenAIGatewayHandler {
return &OpenAIGatewayHandler{
gatewayService: gatewayService,
billingCacheService: billingCacheService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
opsService: opsService,
}
}
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get apiKey and user from context (set by APIKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
@@ -82,7 +79,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Extract model and stream
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
setOpsRequestContext(c, reqModel, reqStream)
// 验证 model 必填
if reqModel == "" {
@@ -239,7 +235,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
Result: result,
APIKey: apiKey,
ApiKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
@@ -282,7 +278,6 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
@@ -302,7 +297,6 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
// errorResponse returns OpenAI API format error response
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,

View File

@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
SiteName: settings.SiteName,
SiteLogo: settings.SiteLogo,
SiteSubtitle: settings.SiteSubtitle,
APIBaseURL: settings.APIBaseURL,
ApiBaseUrl: settings.ApiBaseUrl,
ContactInfo: settings.ContactInfo,
DocURL: settings.DocURL,
DocUrl: settings.DocUrl,
Version: h.version,
})
}

View File

@@ -18,11 +18,11 @@ import (
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
apiKeyService *service.APIKeyService
apiKeyService *service.ApiKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler {
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
apiKeyService: apiKeyService,
@@ -111,7 +111,7 @@ func (h *UsageHandler) List(c *gin.Context) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := usagestats.UsageLogFilters{
UserID: subject.UserID, // Always filter by current user for security
APIKeyID: apiKeyID,
ApiKeyID: apiKeyID,
Model: model,
Stream: stream,
BillingType: billingType,
@@ -235,7 +235,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
var stats *service.UsageStats
var err error
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
}
@@ -346,49 +346,49 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
})
}
// BatchAPIKeysUsageRequest represents the request for batch API keys usage
type BatchAPIKeysUsageRequest struct {
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
// BatchApiKeysUsageRequest represents the request for batch API keys usage
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// DashboardAPIKeysUsage handles getting usage stats for user's own API keys
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req BatchAPIKeysUsageRequest
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.APIKeyIDs) == 0 {
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
// Limit the number of API key IDs to prevent SQL parameter overflow
if len(req.APIKeyIDs) > 100 {
if len(req.ApiKeyIDs) > 100 {
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
return
}
validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs)
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
if len(validAPIKeyIDs) == 0 {
if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.ErrorFrom(c, err)
return

View File

@@ -10,7 +10,6 @@ import (
// ProvideAdminHandlers creates the AdminHandlers struct
func ProvideAdminHandlers(
dashboardHandler *admin.DashboardHandler,
opsHandler *admin.OpsHandler,
userHandler *admin.UserHandler,
groupHandler *admin.GroupHandler,
accountHandler *admin.AccountHandler,
@@ -28,7 +27,6 @@ func ProvideAdminHandlers(
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
Ops: opsHandler,
User: userHandler,
Group: groupHandler,
Account: accountHandler,
@@ -98,7 +96,6 @@ var ProviderSet = wire.NewSet(
// Admin handlers
admin.NewDashboardHandler,
admin.NewOpsHandler,
admin.NewUserHandler,
admin.NewGroupHandler,
admin.NewAccountHandler,

View File

@@ -1,4 +1,3 @@
// Package claude provides Claude API client constants and utilities.
package claude
// Claude Code 客户端相关常量
@@ -17,13 +16,13 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header不需要 claude-code beta
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header不包含 oauth
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header不包含 oauth / claude-code
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders are the default request headers for Claude Code client.
// Claude Code 客户端默认请求头
var DefaultHeaders = map[string]string{
"User-Agent": "claude-cli/2.0.62 (external, cli)",
"X-Stainless-Lang": "js",

View File

@@ -1,4 +1,3 @@
// Package errors provides custom error types and error handling utilities.
// nolint:mnd
package errors

View File

@@ -1,7 +1,7 @@
// Package gemini provides minimal fallback model metadata for Gemini native endpoints.
package gemini
// This package is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
// This package provides minimal fallback model metadata for Gemini native endpoints.
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
type Model struct {
Name string `json:"name"`

View File

@@ -1,4 +1,3 @@
// Package googleapi provides utilities for Google API interactions.
package googleapi
import "net/http"

View File

@@ -1,4 +1,3 @@
// Package oauth provides OAuth 2.0 utilities including PKCE flow, session management, and token exchange.
package oauth
import (

View File

@@ -1,4 +1,3 @@
// Package openai provides OpenAI API models and configuration.
package openai
import _ "embed"

View File

@@ -327,7 +327,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return &claims, nil
}
// UserInfo extracts user information from ID Token claims
// ExtractUserInfo extracts user information from ID Token claims
type UserInfo struct {
Email string
ChatGPTAccountID string

View File

@@ -1,4 +1,3 @@
// Package pagination provides utilities for handling paginated queries and results.
package pagination
// PaginationParams 分页参数

View File

@@ -1,4 +1,3 @@
// Package response provides HTTP response utilities for standardized API responses and error handling.
package response
import (

View File

@@ -1,4 +1,3 @@
// Package sysutil provides system-level utilities for service management.
package sysutil
import (

View File

@@ -1,4 +1,3 @@
// Package usagestats defines types for tracking and reporting API usage statistics.
package usagestats
import "time"
@@ -11,8 +10,8 @@ type DashboardStats struct {
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
// API Key 统计
TotalAPIKeys int64 `json:"total_api_keys"`
ActiveAPIKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
// 账户统计
TotalAccounts int64 `json:"total_accounts"`
@@ -83,10 +82,10 @@ type UserUsageTrendPoint struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
// APIKeyUsageTrendPoint represents API key usage trend data point
type APIKeyUsageTrendPoint struct {
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
APIKeyID int64 `json:"api_key_id"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
@@ -95,8 +94,8 @@ type APIKeyUsageTrendPoint struct {
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats struct {
// API Key 统计
TotalAPIKeys int64 `json:"total_api_keys"`
ActiveAPIKeys int64 `json:"active_api_keys"`
TotalApiKeys int64 `json:"total_api_keys"`
ActiveApiKeys int64 `json:"active_api_keys"`
// 累计 Token 使用统计
TotalRequests int64 `json:"total_requests"`
@@ -129,7 +128,7 @@ type UserDashboardStats struct {
// UsageLogFilters represents filters for usage log queries
type UsageLogFilters struct {
UserID int64
APIKeyID int64
ApiKeyID int64
AccountID int64
GroupID int64
Model string
@@ -158,9 +157,9 @@ type BatchUserUsageStats struct {
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats struct {
APIKeyID int64 `json:"api_key_id"`
// BatchApiKeyUsageStats represents usage stats for a single API key
type BatchApiKeyUsageStats struct {
ApiKeyID int64 `json:"api_key_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}

View File

@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
name: "filter_by_type",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
},
accType: service.AccountTypeAPIKey,
accType: service.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
},
},
{

View File

@@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *testing.T) {
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *t
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient)
apiKeyRepo := NewApiKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *t
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.APIKey{
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",

View File

@@ -24,7 +24,7 @@ type apiKeyCache struct {
rdb *redis.Client
}
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
return &apiKeyCache{rdb: rdb}
}

View File

@@ -13,11 +13,11 @@ import (
"github.com/stretchr/testify/suite"
)
type APIKeyCacheSuite struct {
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -78,7 +78,7 @@ func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
}
}
func (s *APIKeyCacheSuite) TestDailyUsage() {
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
@@ -122,6 +122,6 @@ func (s *APIKeyCacheSuite) TestDailyUsage() {
}
}
func TestAPIKeyCacheSuite(t *testing.T) {
suite.Run(t, new(APIKeyCacheSuite))
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestAPIKeyRateLimitKey(t *testing.T) {
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64

View File

@@ -16,17 +16,17 @@ type apiKeyRepository struct {
client *dbent.Client
}
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
created, err := r.client.ApiKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
return nil, err
}
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 不加载完整的 APIKey 实体及其关联数据User、Group 等)
// - 不加载完整的 ApiKey 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
m, err := r.activeQuery().
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
return 0, service.ErrApiKeyNotFound
}
return 0, err
}
return m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.APIKey.Update().
builder := r.client.ApiKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
}
if affected == 0 {
// 更新影响行数为 0说明记录不存在或已被软删除。
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.APIKey.Update().
affected, err := r.client.ApiKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.APIKey.Query().
exists, err := r.client.ApiKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
if exists {
return nil
}
return service.ErrAPIKeyNotFound
return service.ErrApiKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
return nil, nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
return []int64{}, nil
}
ids, err := r.client.APIKey.Query().
ids, err := r.client.ApiKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return nil, nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
// SearchApiKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyw
return nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
outKeys := make([]service.ApiKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyw
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.APIKey.Update().
n, err := r.client.ApiKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
if m == nil {
return nil
}
out := &service.APIKey{
out := &service.ApiKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,

View File

@@ -12,30 +12,30 @@ import (
"github.com/stretchr/testify/suite"
)
type APIKeyRepoSuite struct {
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *APIKeyRepoSuite) SetupTest() {
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
}
func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(APIKeyRepoSuite))
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *APIKeyRepoSuite) TestCreate() {
func (s *ApiKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
@@ -51,16 +51,16 @@ func (s *APIKeyRepoSuite) TestCreate() {
s.Require().Equal("sk-create-test", got.Key)
}
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *APIKeyRepoSuite) TestGetByKey() {
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
@@ -78,16 +78,16 @@ func (s *APIKeyRepoSuite) TestGetByKey() {
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *APIKeyRepoSuite) TestUpdate() {
func (s *ApiKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
@@ -108,10 +108,10 @@ func (s *APIKeyRepoSuite) TestUpdate() {
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
@@ -131,9 +131,9 @@ func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
// --- Delete ---
func (s *APIKeyRepoSuite) TestDelete() {
func (s *ApiKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com")
key := &service.APIKey{
key := &service.ApiKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
@@ -150,10 +150,10 @@ func (s *APIKeyRepoSuite) TestDelete() {
// --- ListByUserID / CountByUserID ---
func (s *APIKeyRepoSuite) TestListByUserID() {
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateAPIKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateAPIKey(user.ID, "sk-list-2", "Key 2", nil)
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
@@ -161,10 +161,10 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
s.Require().Equal(int64(2), page.Total)
}
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
s.mustCreateAPIKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
@@ -174,10 +174,10 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
s.Require().Equal(3, page.Pages)
}
func (s *APIKeyRepoSuite) TestCountByUserID() {
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com")
s.mustCreateAPIKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-count-2", "K2", nil)
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
@@ -186,13 +186,13 @@ func (s *APIKeyRepoSuite) TestCountByUserID() {
// --- ListByGroupID / CountByGroupID ---
func (s *APIKeyRepoSuite) TestListByGroupID() {
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
s.mustCreateAPIKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-grp-3", "K3", nil) // no group
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
@@ -202,10 +202,10 @@ func (s *APIKeyRepoSuite) TestListByGroupID() {
s.Require().NotNil(keys[0].User)
}
func (s *APIKeyRepoSuite) TestCountByGroupID() {
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateAPIKey(user.ID, "sk-gc-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
@@ -214,9 +214,9 @@ func (s *APIKeyRepoSuite) TestCountByGroupID() {
// --- ExistsByKey ---
func (s *APIKeyRepoSuite) TestExistsByKey() {
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com")
s.mustCreateAPIKey(user.ID, "sk-exists", "K", nil)
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
@@ -227,47 +227,47 @@ func (s *APIKeyRepoSuite) TestExistsByKey() {
s.Require().False(notExists)
}
// --- SearchAPIKeys ---
// --- SearchApiKeys ---
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := s.mustCreateUser("search@test.com")
s.mustCreateAPIKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateAPIKey(user.ID, "sk-search-2", "Development Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchAPIKeys")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateAPIKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateAPIKey(user.ID, "sk-nk-2", "K2", nil)
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateAPIKey(user.ID, "sk-nu-1", "TestKey", nil)
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := s.mustCreateAPIKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateAPIKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateAPIKey(user.ID, "sk-clr-3", "K3", nil) // no group
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
@@ -284,10 +284,10 @@ func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateAPIKey(user.ID, "sk-test-1", "My Key", &group.ID)
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
@@ -320,13 +320,13 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchAPIKeys")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := s.mustCreateAPIKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
@@ -346,7 +346,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
@@ -359,7 +359,7 @@ func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
return userEntityToService(u)
}
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
@@ -370,10 +370,10 @@ func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
return groupEntityToService(g)
}
func (s *APIKeyRepoSuite) mustCreateAPIKey(userID int64, key, name string, groupID *int64) *service.APIKey {
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
s.T().Helper()
k := &service.APIKey{
k := &service.ApiKey{
UserID: userID,
Key: key,
Name: name,

View File

@@ -1,4 +1,4 @@
// Package repository 提供应用程序的基础设施层组件。
// Package infrastructure 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository

View File

@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
return a
}
func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
t.Helper()
ctx := context.Background()
@@ -257,7 +257,7 @@ func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
k.Name = "default"
}
create := client.APIKey.Create().
create := client.ApiKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).

View File

@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.ApiKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {

View File

@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
return u
}
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
@@ -53,28 +53,28 @@ func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.APIKey.Query().
got, err := client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
@@ -86,15 +86,15 @@ func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
repo := NewApiKeyRepository(client)
key := &service.ApiKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
@@ -105,10 +105,10 @@ func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.APIKey.Query().
_, err = client.ApiKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")

View File

@@ -4,12 +4,13 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -17,14 +18,15 @@ import (
type userRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
return &userRepository{client: client}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
return nil
return nil, nil
}
// For each attribute filter, get the set of matching user IDs
// Then intersect all sets to get users matching ALL filters
var resultSet map[int64]struct{}
first := true
if r.sql == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs {
// Query user_attribute_values for this attribute
values, err := r.client.UserAttributeValue.Query().
Where(
userattributevalue.AttributeIDEQ(attrID),
userattributevalue.ValueContainsFold(value),
).
All(ctx)
if err != nil {
continue
}
currentSet := make(map[int64]struct{}, len(values))
for _, v := range values {
currentSet[v.UserID] = struct{}{}
}
if first {
resultSet = currentSet
first = false
} else {
// Intersect with previous results
for userID := range resultSet {
if _, ok := currentSet[userID]; !ok {
delete(resultSet, userID)
}
}
}
// Early exit if no users match
if len(resultSet) == 0 {
return nil
}
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
args = append(args, attrID, "%"+value+"%")
argIndex += 2
}
result := make([]int64, 0, len(resultSet))
for userID := range resultSet {
query := fmt.Sprintf(
`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
strings.Join(clauses, " OR "),
argIndex,
)
args = append(args, len(attrs))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result := make([]int64, 0)
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result = append(result, userID)
}
return result
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {

View File

@@ -28,13 +28,12 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
NewAPIKeyRepository,
NewApiKeyRepository,
NewGroupRepository,
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewUsageLogRepository,
NewOpsRepository,
NewSettingRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
@@ -43,7 +42,8 @@ var ProviderSet = wire.NewSet(
// Cache implementations
NewGatewayCache,
NewBillingCache,
NewAPIKeyCache,
NewApiKeyCache,
NewTempUnschedCache,
ProvideConcurrencyCache,
NewEmailCache,
NewIdentityCache,

View File

@@ -1,4 +1,3 @@
// Package server provides HTTP server setup and routing configuration.
package server
import (
@@ -26,8 +25,8 @@ func ProvideRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
) *gin.Engine {
if cfg.Server.Mode == "release" {

View File

@@ -32,7 +32,7 @@ func adminAuth(
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
return
}
c.Next()
@@ -52,48 +52,19 @@ func adminAuth(
}
}
// WebSocket 请求无法设置自定义 header允许在 query 中携带凭证
if isWebSocketRequest(c) {
if token := strings.TrimSpace(c.Query("token")); token != "" {
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
return
}
if apiKey := strings.TrimSpace(c.Query("api_key")); apiKey != "" {
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
func isWebSocketRequest(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
return true
}
conn := strings.ToLower(c.GetHeader("Connection"))
return strings.Contains(conn, "upgrade") && strings.EqualFold(c.GetHeader("Upgrade"), "websocket")
}
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false

View File

@@ -11,13 +11,13 @@ import (
"github.com/gin-gonic/gin"
)
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) APIKeyAuthMiddleware {
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg, opsService))
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) gin.HandlerFunc {
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
@@ -53,7 +53,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 如果所有header都没有API key
if apiKeyString == "" {
recordOpsAuthError(c, opsService, nil, 401, "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
return
}
@@ -61,40 +60,35 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrAPIKeyNotFound) {
recordOpsAuthError(c, opsService, nil, 401, "Invalid API key")
if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
recordOpsAuthError(c, opsService, nil, 500, "Failed to validate API key")
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "API key is disabled")
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
recordOpsAuthError(c, opsService, apiKey, 401, "User associated with API key not found")
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
recordOpsAuthError(c, opsService, apiKey, 401, "User account is not active")
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -115,14 +109,12 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
apiKey.Group.ID,
)
if err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, "No active subscription found for this group")
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
recordOpsAuthError(c, opsService, apiKey, 403, err.Error())
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
@@ -139,7 +131,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
recordOpsAuthError(c, opsService, apiKey, 429, err.Error())
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
@@ -149,14 +140,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
recordOpsAuthError(c, opsService, apiKey, 403, "Insufficient account balance")
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -167,66 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
}
}
func recordOpsAuthError(c *gin.Context, opsService *service.OpsService, apiKey *service.APIKey, status int, message string) {
if opsService == nil || c == nil {
return
}
errType := "authentication_error"
phase := "auth"
severity := "P3"
switch status {
case 403:
errType = "billing_error"
phase = "billing"
case 429:
errType = "rate_limit_error"
phase = "billing"
severity = "P2"
case 500:
errType = "api_error"
phase = "internal"
severity = "P1"
}
logEntry := &service.OpsErrorLog{
Phase: phase,
Type: errType,
Severity: severity,
StatusCode: status,
Message: message,
ClientIP: c.ClientIP(),
RequestPath: func() string {
if c.Request != nil && c.Request.URL != nil {
return c.Request.URL.Path
}
return ""
}(),
}
if apiKey != nil {
logEntry.APIKeyID = &apiKey.ID
if apiKey.User != nil {
logEntry.UserID = &apiKey.User.ID
}
if apiKey.GroupID != nil {
logEntry.GroupID = apiKey.GroupID
}
if apiKey.Group != nil {
logEntry.Platform = apiKey.Group.Platform
}
}
enqueueOpsAuthErrorLog(opsService, logEntry)
}
// GetAPIKeyFromContext 从上下文中获取API key
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
value, exists := c.Get(string(ContextKeyAPIKey))
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*service.APIKey)
apiKey, ok := value.(*service.ApiKey)
return apiKey, ok
}

View File

@@ -11,16 +11,16 @@ import (
"github.com/gin-gonic/gin"
)
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
}
// APIKeyAuthWithSubscriptionGoogle behaves like APIKeyAuthWithSubscription but returns Google-style errors:
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
@@ -30,7 +30,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrAPIKeyNotFound) {
if errors.Is(err, service.ErrApiKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key")
return
}
@@ -53,7 +53,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
@@ -92,7 +92,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
}
}
c.Set(string(ContextKeyAPIKey), apiKey)
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,

View File

@@ -16,53 +16,53 @@ import (
"github.com/stretchr/testify/require"
)
type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
@@ -74,8 +74,8 @@ type googleErrorResponse struct {
} `json:"error"`
}
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
return service.NewAPIKeyService(
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
@@ -85,16 +85,16 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
)
}
func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -109,16 +109,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return nil, service.ErrAPIKeyNotFound
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -134,16 +134,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("db down")
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -159,13 +159,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
require.Equal(t, "INTERNAL", resp.Error.Status)
}
func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
@@ -176,7 +176,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
}, nil
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
@@ -192,13 +192,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
return &service.APIKey{
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusActive,
@@ -210,7 +210,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
}, nil
},
})
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)

View File

@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
apiKey := &service.ApiKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
return nil, service.ErrApiKeyNotFound
}
clone := *apiKey
return &clone, nil
@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg, nil)))
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
return router
}
type stubAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}

View File

@@ -2,14 +2,11 @@ package middleware
import (
"log"
"regexp"
"time"
"github.com/gin-gonic/gin"
)
var sensitiveQueryParamRE = regexp.MustCompile(`(?i)([?&](?:token|api_key)=)[^&#]*`)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
@@ -29,7 +26,7 @@ func Logger() gin.HandlerFunc {
method := c.Request.Method
// 请求路径
path := sensitiveQueryParamRE.ReplaceAllString(c.Request.URL.RequestURI(), "${1}***")
path := c.Request.URL.Path
// 状态码
statusCode := c.Writer.Status()

View File

@@ -1,5 +1,3 @@
// Package middleware provides HTTP middleware components for authentication,
// authorization, logging, error recovery, and request processing.
package middleware
import (
@@ -17,8 +15,8 @@ const (
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyAPIKey API密钥上下文键
ContextKeyAPIKey ContextKey = "api_key"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)

View File

@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// APIKeyAuthMiddleware API Key 认证中间件类型
type APIKeyAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewAPIKeyAuthMiddleware,
NewApiKeyAuthMiddleware,
)

View File

@@ -17,8 +17,8 @@ func SetupRouter(
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine {
@@ -43,8 +43,8 @@ func registerRoutes(
h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {

View File

@@ -1,4 +1,3 @@
// Package routes 提供 HTTP 路由注册和处理函数
package routes
import (

View File

@@ -50,7 +50,7 @@ func RegisterUserRoutes(
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换

View File

@@ -29,6 +29,9 @@ type Account struct {
RateLimitResetAt *time.Time
OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
@@ -39,6 +42,13 @@ type Account struct {
Groups []*Group
}
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
@@ -54,6 +64,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true
}
@@ -163,6 +176,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
@@ -206,7 +327,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeAPIKey {
if a.Type != AccountTypeApiKey {
return ""
}
baseURL := a.GetCredential("base_url")
@@ -229,7 +350,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -300,15 +421,15 @@ func (a *Account) IsOpenAIOAuth() bool {
return a.IsOpenAI() && a.Type == AccountTypeOAuth
}
func (a *Account) IsOpenAIAPIKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeAPIKey {
if a.Type == AccountTypeApiKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL
@@ -338,8 +459,8 @@ func (a *Account) GetOpenAIIDToken() string {
return a.GetCredential("id_token")
}
func (a *Account) GetOpenAIAPIKey() string {
if !a.IsOpenAIAPIKey() {
func (a *Account) GetOpenAIApiKey() string {
if !a.IsOpenAIApiKey() {
return ""
}
return a.GetCredential("api_key")

View File

@@ -1,5 +1,3 @@
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
package service
import (
@@ -51,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error

View File

@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
panic("unexpected SetOverloaded call")
}
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}

View File

@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
chatgptAccountID = account.GetChatGPTAccountID()
} else if account.Type == "apikey" {
// API Key - use Platform API
authToken = account.GetOpenAIAPIKey()
authToken = account.GetOpenAIApiKey()
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case AccountTypeAPIKey:
case AccountTypeApiKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)

View File

@@ -2,7 +2,7 @@ package service
import "time"
type APIKey struct {
type ApiKey struct {
ID int64
UserID int64
Key string
@@ -15,6 +15,6 @@ type APIKey struct {
Group *Group
}
func (k *APIKey) IsActive() bool {
func (k *ApiKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -14,39 +14,39 @@ import (
)
var (
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type APIKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cache ApiKeyCache
cfg *config.Config
}
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache APIKeyCache,
cache ApiKeyCache,
cfg *config.Config,
) *APIKeyService {
return &APIKeyService{
) *ApiKeyService {
return &ApiKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewAPIKeyService(
}
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *APIKeyService) GenerateKey() (string, error) {
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.APIKeyPrefix
prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" {
prefix = "sk-"
}
@@ -117,10 +117,10 @@ func (s *APIKeyService) GenerateKey() (string, error) {
}
// ValidateCustomKey 验证自定义API Key格式
func (s *APIKeyService) ValidateCustomKey(key string) error {
func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrAPIKeyTooShort
return ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
@@ -131,14 +131,14 @@ func (s *APIKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' {
continue
}
return ErrAPIKeyInvalidChars
return ErrApiKeyInvalidChars
}
return nil
}
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
@@ -150,14 +150,14 @@ func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64)
}
if count >= apiKeyMaxErrorsPerHour {
return ErrAPIKeyRateLimited
return ErrApiKeyRateLimited
}
return nil
}
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
@@ -168,7 +168,7 @@ func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group
}
// Create 创建API Key
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -204,7 +204,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
@@ -219,9 +219,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在,增加错误计数
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
}
key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
// 创建API Key记录
apiKey := &APIKey{
apiKey := &ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
@@ -251,7 +251,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
}
// List 获取用户的API Key列表
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *APIKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil
}
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -272,7 +272,7 @@ func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
}
// GetByID 根据ID获取API Key
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -301,7 +301,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
}
// Update 更新API Key
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
@@ -379,7 +379,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -406,7 +406,7 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -460,7 +460,7 @@ func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
panic("unexpected GetByID call")
}
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
panic("unexpected Update call")
}
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -150,17 +150,17 @@ func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
require.Empty(t, cache.invalidated) // 验证缓存未被清除
}
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestAPIKeyService_Delete_Success(t *testing.T) {
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
@@ -168,37 +168,37 @@ func TestAPIKeyService_Delete_Success(t *testing.T) {
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
}
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)

View File

@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil

View File

@@ -82,7 +82,7 @@ type crsExportResponse struct {
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
} `json:"data"`
}
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeAPIKey
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeAPIKey
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeAPIKey,
Type: AccountTypeApiKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeAPIKey
existing.Type = AccountTypeApiKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

View File

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeAPIKey = "apikey" // API Key类型账号
AccountTypeApiKey = "apikey" // API Key类型账号
)
// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码加密存储
SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -81,20 +81,27 @@ const (
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocURL = "doc_url" // 文档链接
SettingKeyDocUrl = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
const AdminAPIKeyPrefix = "admin-"
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"

View File

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)
// SMTPConfig SMTP配置
type SMTPConfig struct {
// SmtpConfig SMTP配置
type SmtpConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySMTPHost]
host := settings[SettingKeySmtpHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[SettingKeySMTPUseTLS] == "true"
useTLS := settings[SettingKeySmtpUseTLS] == "true"
return &SMTPConfig{
return &SmtpConfig{
Host: host,
Port: port,
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSMTPConfig(ctx)
config, err := s.GetSmtpConfig(ctx)
if err != nil {
return err
}
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {

View File

@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeAPIKey:
apiKey := account.GetOpenAIAPIKey()
case AccountTypeApiKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeAPIKey:
case AccountTypeApiKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
// Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response
var errType, errMsg string
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
APIKey *APIKey
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,

View File

@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyAPIBaseURL,
SettingKeyApiBaseUrl,
SettingKeyContactInfo,
SettingKeyDocURL,
SettingKeyDocUrl,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
DocUrl: settings[SettingKeyDocUrl],
}, nil
}
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SMTPPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocURL] = settings.DocURL
updates[SettingKeyDocUrl] = settings.DocUrl
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
return s.settingRepo.SetMultiple(ctx, updates)
}
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
APIBaseURL: settings[SettingKeyAPIBaseURL],
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
ContactInfo: settings[SettingKeyContactInfo],
DocURL: settings[SettingKeyDocURL],
DocUrl: settings[SettingKeyDocUrl],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SMTPPort = port
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SMTPPort = 587
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
@@ -245,10 +258,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
// 敏感信息直接返回方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
return result
}
@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil
}
// GetAdminAPIKey 获取完整的管理员 API Key仅供内部验证使用
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
return key, nil
}
// DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
}
// IsModelFallbackEnabled 检查是否启用模型兜底机制
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
if err != nil {
return false // Default: disabled
}
return value == "true"
}
// GetFallbackModel 获取指定平台的兜底模型
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
var key string
var defaultModel string
switch platform {
case PlatformAnthropic:
key = SettingKeyFallbackModelAnthropic
defaultModel = "claude-3-5-sonnet-20241022"
case PlatformOpenAI:
key = SettingKeyFallbackModelOpenAI
defaultModel = "gpt-4o"
case PlatformGemini:
key = SettingKeyFallbackModelGemini
defaultModel = "gemini-2.5-pro"
case PlatformAntigravity:
key = SettingKeyFallbackModelAntigravity
defaultModel = "gemini-2.5-pro"
default:
return ""
}
value, err := s.settingRepo.GetValue(ctx, key)
if err != nil || value == "" {
return defaultModel
}
return value
}

View File

@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
@@ -19,12 +19,19 @@ type SystemSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ApiBaseUrl string
ContactInfo string
DocURL string
DocUrl string
DefaultConcurrency int
DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
type PublicSettings struct {
@@ -35,8 +42,8 @@ type PublicSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
APIBaseURL string
ApiBaseUrl string
ContactInfo string
DocURL string
DocUrl string
Version string
}

View File

@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HTMLURL string `json:"html_url"`
HtmlURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HTMLURL string `json:"html_url"`
HtmlUrl string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadURL string `json:"browser_download_url"`
BrowserDownloadUrl string `json:"browser_download_url"`
Size int64 `json:"size"`
}
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadURL,
DownloadURL: a.BrowserDownloadUrl,
Size: a.Size,
}
}
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HTMLURL: release.HTMLURL,
HtmlURL: release.HtmlUrl,
Assets: assets,
},
Cached: false,

View File

@@ -0,0 +1,35 @@
package service
import "time"
// clampInt 将整数限制在指定范围内
func clampInt(value, min, max int) int {
if value < min {
return min
}
if value > max {
return max
}
return value
}
// clampFloat64 将浮点数限制在指定范围内
func clampFloat64(value, min, max float64) float64 {
if value < min {
return min
}
if value > max {
return max
}
return value
}
// remainingSecondsUntil 计算到指定时间的剩余秒数,保证非负
func remainingSecondsUntil(t time.Time) int {
seconds := int(time.Until(t).Seconds())
if seconds < 0 {
return 0
}
return seconds
}

View File

@@ -10,7 +10,7 @@ const (
type UsageLog struct {
ID int64
UserID int64
APIKeyID int64
ApiKeyID int64
AccountID int64
RequestID string
Model string
@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time
User *User
APIKey *APIKey
ApiKey *ApiKey
Account *Account
Group *Group
Subscription *UserSubscription

View File

@@ -17,7 +17,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
APIKeyID int64 `json:"api_key_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
usageLog := &UsageLog{
UserID: req.UserID,
APIKeyID: req.APIKeyID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil
}
// ListByAPIKey 获取API Key的使用日志列表
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil
}
// GetStatsByAPIKey 获取API Key的使用统计
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err)
}
@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
APIKeys []APIKey
ApiKeys []ApiKey
Subscriptions []UserSubscription
}

View File

@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled: input.Enabled,
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err)
}
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def.Enabled = *input.Enabled
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err)
}
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) {
if err != nil {
return validationError(def.Name + " has an invalid pattern")
}
if !re.MatchString(value) {
msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" {
msg = *v.Message
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
}
return false
}
func validateDefinitionPattern(def *UserAttributeDefinition) error {
if def == nil {
return nil
}
if def.Validation.Pattern == nil {
return nil
}
pattern := strings.TrimSpace(*def.Validation.Pattern)
if pattern == "" {
return nil
}
if _, err := regexp.Compile(pattern); err != nil {
return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", fmt.Sprintf("invalid pattern for %s: %v", def.Name, err))
}
return nil
}

View File

@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService {
return svc
}
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
func ProvideAntigravityQuotaRefresher(
accountRepo AccountRepository,
proxyRepo ProxyRepository,
oauthSvc *AntigravityOAuthService,
cfg *config.Config,
) *AntigravityQuotaRefresher {
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
svc.Start()
return svc
}
// ProvideDeferredService creates and starts DeferredService
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
@@ -73,20 +61,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc
}
// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
func ProvideOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
svc := NewOpsMetricsCollector(opsService, concurrencyService)
svc.Start()
return svc
}
// ProvideOpsAlertService creates and starts OpsAlertService.
func ProvideOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
svc := NewOpsAlertService(opsService, userService, emailService)
svc.Start()
return svc
}
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
svc := NewConcurrencyService(cache)
@@ -101,14 +75,13 @@ var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewAPIKeyService,
NewApiKeyService,
NewGroupService,
NewAccountService,
NewProxyService,
NewRedeemService,
NewUsageService,
NewDashboardService,
NewOpsService,
ProvidePricingService,
NewBillingService,
NewBillingCacheService,
@@ -139,8 +112,7 @@ var ProviderSet = wire.NewSet(
ProvideTokenRefreshService,
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
ProvideOpsMetricsCollector,
ProvideOpsAlertService,
NewAntigravityQuotaFetcher,
NewUserAttributeService,
NewUsageCache,
)

View File

@@ -1,4 +1,3 @@
// Package setup provides CLI-based installation wizard for initial system configuration.
package setup
import (

View File

@@ -345,7 +345,7 @@ func writeConfigFile(cfg *SetupConfig) error {
Default struct {
UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"`
APIKeyPrefix string `yaml:"api_key_prefix"`
ApiKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"`
} `yaml:"default"`
RateLimit struct {
@@ -367,12 +367,12 @@ func writeConfigFile(cfg *SetupConfig) error {
Default: struct {
UserConcurrency int `yaml:"user_concurrency"`
UserBalance float64 `yaml:"user_balance"`
APIKeyPrefix string `yaml:"api_key_prefix"`
ApiKeyPrefix string `yaml:"api_key_prefix"`
RateMultiplier float64 `yaml:"rate_multiplier"`
}{
UserConcurrency: 5,
UserBalance: 0,
APIKeyPrefix: "sk-",
ApiKeyPrefix: "sk-",
RateMultiplier: 1.0,
},
RateLimit: struct {

View File

@@ -1,6 +1,5 @@
//go:build !embed
// Package web provides web server functionality including embedded frontend support.
package web
import (