chore: 更新依赖、配置和代码生成
主要更新: - 更新 go.mod/go.sum 依赖 - 重新生成 Ent ORM 代码 - 更新 Wire 依赖注入配置 - 添加 docker-compose.override.yml 到 .gitignore - 更新 README 文档(Simple Mode 说明和已知问题) - 清理调试日志 - 其他代码优化和格式修复
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
// Package config provides application configuration management.
|
||||
package config
|
||||
|
||||
import (
|
||||
@@ -140,7 +139,7 @@ type GatewayConfig struct {
|
||||
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
|
||||
|
||||
// API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容)
|
||||
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||
InjectBetaForApiKey bool `mapstructure:"inject_beta_for_apikey"`
|
||||
|
||||
// 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义)
|
||||
FailoverOn400 bool `mapstructure:"failover_on_400"`
|
||||
@@ -242,7 +241,7 @@ type DefaultConfig struct {
|
||||
AdminPassword string `mapstructure:"admin_password"`
|
||||
UserConcurrency int `mapstructure:"user_concurrency"`
|
||||
UserBalance float64 `mapstructure:"user_balance"`
|
||||
APIKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
ApiKeyPrefix string `mapstructure:"api_key_prefix"`
|
||||
RateMultiplier float64 `mapstructure:"rate_multiplier"`
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package config provides application configuration management.
|
||||
package config
|
||||
|
||||
import "github.com/google/wire"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// Package admin provides HTTP handlers for administrative operations including
|
||||
// dashboard statistics, user management, API key management, and account management.
|
||||
package admin
|
||||
|
||||
import (
|
||||
@@ -77,8 +75,8 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
|
||||
"active_users": stats.ActiveUsers,
|
||||
|
||||
// API Key 统计
|
||||
"total_api_keys": stats.TotalAPIKeys,
|
||||
"active_api_keys": stats.ActiveAPIKeys,
|
||||
"total_api_keys": stats.TotalApiKeys,
|
||||
"active_api_keys": stats.ActiveApiKeys,
|
||||
|
||||
// 账户统计
|
||||
"total_accounts": stats.TotalAccounts,
|
||||
@@ -195,10 +193,10 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetAPIKeyUsageTrend handles getting API key usage trend data
|
||||
// GetApiKeyUsageTrend handles getting API key usage trend data
|
||||
// GET /api/v1/admin/dashboard/api-keys-trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
|
||||
func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
limitStr := c.DefaultQuery("limit", "5")
|
||||
@@ -207,7 +205,7 @@ func (h *DashboardHandler) GetAPIKeyUsageTrend(c *gin.Context) {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetAPIKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
trend, err := h.dashboardService.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage trend")
|
||||
return
|
||||
@@ -275,26 +273,26 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
response.Success(c, gin.H{"stats": stats})
|
||||
}
|
||||
|
||||
// BatchAPIKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchAPIKeysUsageRequest struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBatchAPIKeysUsage handles getting usage stats for multiple API keys
|
||||
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
|
||||
// POST /api/v1/admin/dashboard/api-keys-usage
|
||||
func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
var req BatchAPIKeysUsageRequest
|
||||
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
@@ -237,9 +237,9 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outKeys := make([]dto.APIKey, 0, len(keys))
|
||||
outKeys := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *dto.APIKeyFromService(&keys[i]))
|
||||
outKeys = append(outKeys, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, outKeys, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -36,24 +36,29 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPassword: settings.SMTPPassword,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
SmtpHost: settings.SmtpHost,
|
||||
SmtpPort: settings.SmtpPort,
|
||||
SmtpUsername: settings.SmtpUsername,
|
||||
SmtpPassword: settings.SmtpPassword,
|
||||
SmtpFrom: settings.SmtpFrom,
|
||||
SmtpFromName: settings.SmtpFromName,
|
||||
SmtpUseTLS: settings.SmtpUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: settings.TurnstileSecretKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DocUrl: settings.DocUrl,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -64,13 +69,13 @@ type UpdateSettingsRequest struct {
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
@@ -81,13 +86,20 @@ type UpdateSettingsRequest struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
@@ -106,8 +118,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
if req.DefaultBalance < 0 {
|
||||
req.DefaultBalance = 0
|
||||
}
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// Turnstile 参数验证
|
||||
@@ -143,24 +155,29 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
SmtpHost: req.SmtpHost,
|
||||
SmtpPort: req.SmtpPort,
|
||||
SmtpUsername: req.SmtpUsername,
|
||||
SmtpPassword: req.SmtpPassword,
|
||||
SmtpFrom: req.SmtpFrom,
|
||||
SmtpFromName: req.SmtpFromName,
|
||||
SmtpUseTLS: req.SmtpUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ApiBaseUrl: req.ApiBaseUrl,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DocUrl: req.DocUrl,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
}
|
||||
|
||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||
@@ -178,67 +195,72 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPassword: updatedSettings.SMTPPassword,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
SmtpHost: updatedSettings.SmtpHost,
|
||||
SmtpPort: updatedSettings.SmtpPort,
|
||||
SmtpUsername: updatedSettings.SmtpUsername,
|
||||
SmtpPassword: updatedSettings.SmtpPassword,
|
||||
SmtpFrom: updatedSettings.SmtpFrom,
|
||||
SmtpFromName: updatedSettings.SmtpFromName,
|
||||
SmtpUseTLS: updatedSettings.SmtpUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKey: updatedSettings.TurnstileSecretKey,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ApiBaseUrl: updatedSettings.ApiBaseUrl,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DocUrl: updatedSettings.DocUrl,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
})
|
||||
}
|
||||
|
||||
// TestSMTPRequest 测试SMTP连接请求
|
||||
type TestSMTPRequest struct {
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
// TestSmtpRequest 测试SMTP连接请求
|
||||
type TestSmtpRequest struct {
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// TestSMTPConnection 测试SMTP连接
|
||||
// TestSmtpConnection 测试SMTP连接
|
||||
// POST /api/v1/admin/settings/test-smtp
|
||||
func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
var req TestSMTPRequest
|
||||
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
|
||||
var req TestSmtpRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SMTPPassword
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
Port: req.SMTPPort,
|
||||
Username: req.SMTPUsername,
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
UseTLS: req.SMTPUseTLS,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
err := h.emailService.TestSMTPConnectionWithConfig(config)
|
||||
err := h.emailService.TestSmtpConnectionWithConfig(config)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -250,13 +272,13 @@ func (h *SettingHandler) TestSMTPConnection(c *gin.Context) {
|
||||
// SendTestEmailRequest 发送测试邮件请求
|
||||
type SendTestEmailRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
SMTPHost string `json:"smtp_host" binding:"required"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
SmtpHost string `json:"smtp_host" binding:"required"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
}
|
||||
|
||||
// SendTestEmail 发送测试邮件
|
||||
@@ -268,27 +290,27 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SMTPPort <= 0 {
|
||||
req.SMTPPort = 587
|
||||
if req.SmtpPort <= 0 {
|
||||
req.SmtpPort = 587
|
||||
}
|
||||
|
||||
// 如果未提供密码,从数据库获取已保存的密码
|
||||
password := req.SMTPPassword
|
||||
password := req.SmtpPassword
|
||||
if password == "" {
|
||||
savedConfig, err := h.emailService.GetSMTPConfig(c.Request.Context())
|
||||
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
|
||||
if err == nil && savedConfig != nil {
|
||||
password = savedConfig.Password
|
||||
}
|
||||
}
|
||||
|
||||
config := &service.SMTPConfig{
|
||||
Host: req.SMTPHost,
|
||||
Port: req.SMTPPort,
|
||||
Username: req.SMTPUsername,
|
||||
config := &service.SmtpConfig{
|
||||
Host: req.SmtpHost,
|
||||
Port: req.SmtpPort,
|
||||
Username: req.SmtpUsername,
|
||||
Password: password,
|
||||
From: req.SMTPFrom,
|
||||
FromName: req.SMTPFromName,
|
||||
UseTLS: req.SMTPUseTLS,
|
||||
From: req.SmtpFrom,
|
||||
FromName: req.SmtpFromName,
|
||||
UseTLS: req.SmtpUseTLS,
|
||||
}
|
||||
|
||||
siteName := h.settingService.GetSiteName(c.Request.Context())
|
||||
@@ -333,10 +355,10 @@ func (h *SettingHandler) SendTestEmail(c *gin.Context) {
|
||||
response.Success(c, gin.H{"message": "Test email sent successfully"})
|
||||
}
|
||||
|
||||
// GetAdminAPIKey 获取管理员 API Key 状态
|
||||
// GetAdminApiKey 获取管理员 API Key 状态
|
||||
// GET /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminAPIKeyStatus(c.Request.Context())
|
||||
func (h *SettingHandler) GetAdminApiKey(c *gin.Context) {
|
||||
maskedKey, exists, err := h.settingService.GetAdminApiKeyStatus(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -348,10 +370,10 @@ func (h *SettingHandler) GetAdminAPIKey(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// RegenerateAdminAPIKey 生成/重新生成管理员 API Key
|
||||
// RegenerateAdminApiKey 生成/重新生成管理员 API Key
|
||||
// POST /api/v1/admin/settings/admin-api-key/regenerate
|
||||
func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminAPIKey(c.Request.Context())
|
||||
func (h *SettingHandler) RegenerateAdminApiKey(c *gin.Context) {
|
||||
key, err := h.settingService.GenerateAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -362,10 +384,10 @@ func (h *SettingHandler) RegenerateAdminAPIKey(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAdminAPIKey 删除管理员 API Key
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
// DELETE /api/v1/admin/settings/admin-api-key
|
||||
func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminAPIKey(c.Request.Context()); err != nil {
|
||||
func (h *SettingHandler) DeleteAdminApiKey(c *gin.Context) {
|
||||
if err := h.settingService.DeleteAdminApiKey(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
apiKeyService *service.ApiKeyService
|
||||
adminService service.AdminService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
adminService service.AdminService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
@@ -125,7 +125,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: userID,
|
||||
APIKeyID: apiKeyID,
|
||||
ApiKeyID: apiKeyID,
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
Model: model,
|
||||
@@ -207,7 +207,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
}
|
||||
|
||||
if apiKeyID > 0 {
|
||||
stats, err := h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -269,9 +269,9 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) {
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// SearchAPIKeys handles searching API keys by user
|
||||
// SearchApiKeys handles searching API keys by user
|
||||
// GET /api/v1/admin/usage/search-api-keys
|
||||
func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
|
||||
userIDStr := c.Query("user_id")
|
||||
keyword := c.Query("q")
|
||||
|
||||
@@ -285,22 +285,22 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
userID = id
|
||||
}
|
||||
|
||||
keys, err := h.apiKeyService.SearchAPIKeys(c.Request.Context(), userID, keyword, 30)
|
||||
keys, err := h.apiKeyService.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return simplified API key list (only id and name)
|
||||
type SimpleAPIKey struct {
|
||||
type SimpleApiKey struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
result := make([]SimpleAPIKey, len(keys))
|
||||
result := make([]SimpleApiKey, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = SimpleAPIKey{
|
||||
result[i] = SimpleApiKey{
|
||||
ID: k.ID,
|
||||
Name: k.Name,
|
||||
UserID: k.UserID,
|
||||
|
||||
@@ -243,9 +243,9 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.APIKey, 0, len(keys))
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.APIKeyFromService(&keys[i]))
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
|
||||
@@ -14,11 +14,11 @@ import (
|
||||
|
||||
// APIKeyHandler handles API key-related requests
|
||||
type APIKeyHandler struct {
|
||||
apiKeyService *service.APIKeyService
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewAPIKeyHandler creates a new APIKeyHandler
|
||||
func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
|
||||
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
|
||||
return &APIKeyHandler{
|
||||
apiKeyService: apiKeyService,
|
||||
}
|
||||
@@ -56,9 +56,9 @@ func (h *APIKeyHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.APIKey, 0, len(keys))
|
||||
out := make([]dto.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
out = append(out, *dto.APIKeyFromService(&keys[i]))
|
||||
out = append(out, *dto.ApiKeyFromService(&keys[i]))
|
||||
}
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (h *APIKeyHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Create handles creating a new API key
|
||||
@@ -108,7 +108,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.CreateAPIKeyRequest{
|
||||
svcReq := service.CreateApiKeyRequest{
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
CustomKey: req.CustomKey,
|
||||
@@ -119,7 +119,7 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Update handles updating an API key
|
||||
@@ -143,7 +143,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
svcReq := service.UpdateAPIKeyRequest{}
|
||||
svcReq := service.UpdateApiKeyRequest{}
|
||||
if req.Name != "" {
|
||||
svcReq.Name = &req.Name
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.APIKeyFromService(key))
|
||||
response.Success(c, dto.ApiKeyFromService(key))
|
||||
}
|
||||
|
||||
// Delete handles deleting an API key
|
||||
|
||||
@@ -5,13 +5,13 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
SMTPUsername string `json:"smtp_username"`
|
||||
SMTPPassword string `json:"smtp_password,omitempty"`
|
||||
SMTPFrom string `json:"smtp_from_email"`
|
||||
SMTPFromName string `json:"smtp_from_name"`
|
||||
SMTPUseTLS bool `json:"smtp_use_tls"`
|
||||
SmtpHost string `json:"smtp_host"`
|
||||
SmtpPort int `json:"smtp_port"`
|
||||
SmtpUsername string `json:"smtp_username"`
|
||||
SmtpPassword string `json:"smtp_password,omitempty"`
|
||||
SmtpFrom string `json:"smtp_from_email"`
|
||||
SmtpFromName string `json:"smtp_from_name"`
|
||||
SmtpUseTLS bool `json:"smtp_use_tls"`
|
||||
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
@@ -20,12 +20,19 @@ type SystemSettings struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
@@ -36,8 +43,8 @@ type PublicSettings struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ApiBaseUrl string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
DocUrl string `json:"doc_url"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
// AdminHandlers contains all admin-related HTTP handlers
|
||||
type AdminHandlers struct {
|
||||
Dashboard *admin.DashboardHandler
|
||||
Ops *admin.OpsHandler
|
||||
User *admin.UserHandler
|
||||
Group *admin.GroupHandler
|
||||
Account *admin.AccountHandler
|
||||
|
||||
@@ -22,7 +22,6 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
opsService *service.OpsService
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -30,21 +29,19 @@ func NewOpenAIGatewayHandler(
|
||||
gatewayService *service.OpenAIGatewayService,
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
billingCacheService *service.BillingCacheService,
|
||||
opsService *service.OpsService,
|
||||
) *OpenAIGatewayHandler {
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatNone),
|
||||
opsService: opsService,
|
||||
}
|
||||
}
|
||||
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Get apiKey and user from context (set by APIKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
apiKey, ok := middleware2.GetApiKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
@@ -82,7 +79,6 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// Extract model and stream
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
|
||||
// 验证 model 必填
|
||||
if reqModel == "" {
|
||||
@@ -239,7 +235,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
ApiKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
@@ -282,7 +278,6 @@ func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, st
|
||||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||||
if streamStarted {
|
||||
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
@@ -302,7 +297,6 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
recordOpsError(c, h.opsService, status, errType, message, service.PlatformOpenAI)
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"type": errType,
|
||||
|
||||
@@ -39,9 +39,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ApiBaseUrl: settings.ApiBaseUrl,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DocUrl: settings.DocUrl,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -18,11 +18,11 @@ import (
|
||||
// UsageHandler handles usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
apiKeyService *service.ApiKeyService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new UsageHandler
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler {
|
||||
func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.ApiKeyService) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
@@ -111,7 +111,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
filters := usagestats.UsageLogFilters{
|
||||
UserID: subject.UserID, // Always filter by current user for security
|
||||
APIKeyID: apiKeyID,
|
||||
ApiKeyID: apiKeyID,
|
||||
Model: model,
|
||||
Stream: stream,
|
||||
BillingType: billingType,
|
||||
@@ -235,7 +235,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
|
||||
var stats *service.UsageStats
|
||||
var err error
|
||||
if apiKeyID > 0 {
|
||||
stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
|
||||
} else {
|
||||
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime)
|
||||
}
|
||||
@@ -346,49 +346,49 @@ func (h *UsageHandler) DashboardModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchAPIKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchAPIKeysUsageRequest struct {
|
||||
APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
// BatchApiKeysUsageRequest represents the request for batch API keys usage
|
||||
type BatchApiKeysUsageRequest struct {
|
||||
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
|
||||
}
|
||||
|
||||
// DashboardAPIKeysUsage handles getting usage stats for user's own API keys
|
||||
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
|
||||
// POST /api/v1/usage/dashboard/api-keys-usage
|
||||
func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
||||
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req BatchAPIKeysUsageRequest
|
||||
var req BatchApiKeysUsageRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.APIKeyIDs) == 0 {
|
||||
if len(req.ApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
// Limit the number of API key IDs to prevent SQL parameter overflow
|
||||
if len(req.APIKeyIDs) > 100 {
|
||||
if len(req.ApiKeyIDs) > 100 {
|
||||
response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)")
|
||||
return
|
||||
}
|
||||
|
||||
validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs)
|
||||
validApiKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.ApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(validAPIKeyIDs) == 0 {
|
||||
if len(validApiKeyIDs) == 0 {
|
||||
response.Success(c, gin.H{"stats": map[string]any{}})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
||||
stats, err := h.usageService.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||
func ProvideAdminHandlers(
|
||||
dashboardHandler *admin.DashboardHandler,
|
||||
opsHandler *admin.OpsHandler,
|
||||
userHandler *admin.UserHandler,
|
||||
groupHandler *admin.GroupHandler,
|
||||
accountHandler *admin.AccountHandler,
|
||||
@@ -28,7 +27,6 @@ func ProvideAdminHandlers(
|
||||
) *AdminHandlers {
|
||||
return &AdminHandlers{
|
||||
Dashboard: dashboardHandler,
|
||||
Ops: opsHandler,
|
||||
User: userHandler,
|
||||
Group: groupHandler,
|
||||
Account: accountHandler,
|
||||
@@ -98,7 +96,6 @@ var ProviderSet = wire.NewSet(
|
||||
|
||||
// Admin handlers
|
||||
admin.NewDashboardHandler,
|
||||
admin.NewOpsHandler,
|
||||
admin.NewUserHandler,
|
||||
admin.NewGroupHandler,
|
||||
admin.NewAccountHandler,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package claude provides Claude API client constants and utilities.
|
||||
package claude
|
||||
|
||||
// Claude Code 客户端相关常量
|
||||
@@ -17,13 +16,13 @@ const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleav
|
||||
// HaikuBetaHeader Haiku 模型使用的 anthropic-beta header(不需要 claude-code beta)
|
||||
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
|
||||
|
||||
// APIKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const APIKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
// ApiKeyBetaHeader API-key 账号建议使用的 anthropic-beta header(不包含 oauth)
|
||||
const ApiKeyBetaHeader = BetaClaudeCode + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
|
||||
|
||||
// APIKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const APIKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
// ApiKeyHaikuBetaHeader Haiku 模型在 API-key 账号下使用的 anthropic-beta header(不包含 oauth / claude-code)
|
||||
const ApiKeyHaikuBetaHeader = BetaInterleavedThinking
|
||||
|
||||
// DefaultHeaders are the default request headers for Claude Code client.
|
||||
// Claude Code 客户端默认请求头
|
||||
var DefaultHeaders = map[string]string{
|
||||
"User-Agent": "claude-cli/2.0.62 (external, cli)",
|
||||
"X-Stainless-Lang": "js",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package errors provides custom error types and error handling utilities.
|
||||
// nolint:mnd
|
||||
package errors
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Package gemini provides minimal fallback model metadata for Gemini native endpoints.
|
||||
package gemini
|
||||
|
||||
// This package is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
// This package provides minimal fallback model metadata for Gemini native endpoints.
|
||||
// It is used when upstream model listing is unavailable (e.g. OAuth token missing AI Studio scopes).
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package googleapi provides utilities for Google API interactions.
|
||||
package googleapi
|
||||
|
||||
import "net/http"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package oauth provides OAuth 2.0 utilities including PKCE flow, session management, and token exchange.
|
||||
package oauth
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package openai provides OpenAI API models and configuration.
|
||||
package openai
|
||||
|
||||
import _ "embed"
|
||||
|
||||
@@ -327,7 +327,7 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// UserInfo extracts user information from ID Token claims
|
||||
// ExtractUserInfo extracts user information from ID Token claims
|
||||
type UserInfo struct {
|
||||
Email string
|
||||
ChatGPTAccountID string
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package pagination provides utilities for handling paginated queries and results.
|
||||
package pagination
|
||||
|
||||
// PaginationParams 分页参数
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package response provides HTTP response utilities for standardized API responses and error handling.
|
||||
package response
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package sysutil provides system-level utilities for service management.
|
||||
package sysutil
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package usagestats defines types for tracking and reporting API usage statistics.
|
||||
package usagestats
|
||||
|
||||
import "time"
|
||||
@@ -11,8 +10,8 @@ type DashboardStats struct {
|
||||
ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数
|
||||
|
||||
// API Key 统计
|
||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||
ActiveAPIKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"` // 状态为 active 的 API Key 数
|
||||
|
||||
// 账户统计
|
||||
TotalAccounts int64 `json:"total_accounts"`
|
||||
@@ -83,10 +82,10 @@ type UserUsageTrendPoint struct {
|
||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||
}
|
||||
|
||||
// APIKeyUsageTrendPoint represents API key usage trend data point
|
||||
type APIKeyUsageTrendPoint struct {
|
||||
// ApiKeyUsageTrendPoint represents API key usage trend data point
|
||||
type ApiKeyUsageTrendPoint struct {
|
||||
Date string `json:"date"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
KeyName string `json:"key_name"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
@@ -95,8 +94,8 @@ type APIKeyUsageTrendPoint struct {
|
||||
// UserDashboardStats 用户仪表盘统计
|
||||
type UserDashboardStats struct {
|
||||
// API Key 统计
|
||||
TotalAPIKeys int64 `json:"total_api_keys"`
|
||||
ActiveAPIKeys int64 `json:"active_api_keys"`
|
||||
TotalApiKeys int64 `json:"total_api_keys"`
|
||||
ActiveApiKeys int64 `json:"active_api_keys"`
|
||||
|
||||
// 累计 Token 使用统计
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
@@ -129,7 +128,7 @@ type UserDashboardStats struct {
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
type UsageLogFilters struct {
|
||||
UserID int64
|
||||
APIKeyID int64
|
||||
ApiKeyID int64
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
@@ -158,9 +157,9 @@ type BatchUserUsageStats struct {
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
// BatchApiKeyUsageStats represents usage stats for a single API key
|
||||
type BatchApiKeyUsageStats struct {
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
@@ -135,12 +135,12 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
name: "filter_by_type",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeApiKey})
|
||||
},
|
||||
accType: service.AccountTypeAPIKey,
|
||||
accType: service.AccountTypeApiKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
|
||||
s.Require().Equal(service.AccountTypeApiKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *te
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *testing.T) {
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
@@ -98,7 +98,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *t
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewAPIKeyRepository(entClient)
|
||||
apiKeyRepo := NewApiKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
@@ -110,7 +110,7 @@ func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsAPIKeys(t *t
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.APIKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
|
||||
@@ -24,7 +24,7 @@ type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
||||
func NewApiKeyCache(rdb *redis.Client) service.ApiKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type APIKeyCacheSuite struct {
|
||||
type ApiKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
@@ -78,7 +78,7 @@ func (s *APIKeyCacheSuite) TestCreateAttemptCount() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyCacheSuite) TestDailyUsage() {
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
@@ -122,6 +122,6 @@ func (s *APIKeyCacheSuite) TestDailyUsage() {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyCacheSuite))
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKeyRateLimitKey(t *testing.T) {
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
|
||||
@@ -16,17 +16,17 @@ type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
|
||||
func NewApiKeyRepository(client *dbent.Client) service.ApiKeyRepository {
|
||||
return &apiKeyRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.ApiKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
return r.client.ApiKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
created, err := r.client.ApiKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
@@ -38,10 +38,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
|
||||
return translatePersistenceError(err, nil, service.ErrApiKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
@@ -49,7 +49,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 不加载完整的 APIKey 实体及其关联数据(User、Group 等)
|
||||
// - 不加载完整的 ApiKey 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
@@ -68,14 +68,14 @@ func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, err
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
return 0, service.ErrApiKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
@@ -83,21 +83,21 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
now := time.Now()
|
||||
builder := r.client.APIKey.Update().
|
||||
builder := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
@@ -114,7 +114,7 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrAPIKeyNotFound
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
@@ -124,18 +124,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.APIKey.Update().
|
||||
affected, err := r.client.ApiKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrAPIKeyNotFound
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.APIKey.Query().
|
||||
exists, err := r.client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
@@ -144,12 +144,12 @@ func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrAPIKeyNotFound
|
||||
return service.ErrApiKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -167,7 +167,7 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, ap
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.APIKey.Query().
|
||||
ids, err := r.client.ApiKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
@@ -199,7 +199,7 @@ func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
@@ -217,7 +217,7 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -225,8 +225,8 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
// SearchApiKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
@@ -241,7 +241,7 @@ func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyw
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
outKeys := make([]service.ApiKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
@@ -250,7 +250,7 @@ func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyw
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.APIKey.Update().
|
||||
n, err := r.client.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
@@ -263,11 +263,11 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
func apiKeyEntityToService(m *dbent.ApiKey) *service.ApiKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
out := &service.ApiKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
|
||||
@@ -12,30 +12,30 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type APIKeyRepoSuite struct {
|
||||
type ApiKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) SetupTest() {
|
||||
func (s *ApiKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
|
||||
s.repo = NewApiKeyRepository(s.client).(*apiKeyRepository)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyRepoSuite))
|
||||
func TestApiKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCreate() {
|
||||
func (s *ApiKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
@@ -51,16 +51,16 @@ func (s *APIKeyRepoSuite) TestCreate() {
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
|
||||
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestGetByKey() {
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.APIKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
@@ -78,16 +78,16 @@ func (s *APIKeyRepoSuite) TestGetByKey() {
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestUpdate() {
|
||||
func (s *ApiKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.APIKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
@@ -108,10 +108,10 @@ func (s *APIKeyRepoSuite) TestUpdate() {
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.APIKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
@@ -131,9 +131,9 @@ func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestDelete() {
|
||||
func (s *ApiKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.APIKey{
|
||||
key := &service.ApiKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
@@ -150,10 +150,10 @@ func (s *APIKeyRepoSuite) TestDelete() {
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateAPIKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
@@ -161,10 +161,10 @@ func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateAPIKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
|
||||
@@ -174,10 +174,10 @@ func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCountByUserID() {
|
||||
func (s *ApiKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateAPIKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-count-2", "K2", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
@@ -186,13 +186,13 @@ func (s *APIKeyRepoSuite) TestCountByUserID() {
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByGroupID() {
|
||||
func (s *ApiKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
s.mustCreateAPIKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
@@ -202,10 +202,10 @@ func (s *APIKeyRepoSuite) TestListByGroupID() {
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCountByGroupID() {
|
||||
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateAPIKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
@@ -214,9 +214,9 @@ func (s *APIKeyRepoSuite) TestCountByGroupID() {
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestExistsByKey() {
|
||||
func (s *ApiKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateAPIKey(user.ID, "sk-exists", "K", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
@@ -227,47 +227,47 @@ func (s *APIKeyRepoSuite) TestExistsByKey() {
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchAPIKeys ---
|
||||
// --- SearchApiKeys ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateAPIKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateAPIKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateAPIKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
|
||||
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateAPIKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
k1 := s.mustCreateAPIKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateAPIKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateAPIKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
@@ -284,10 +284,10 @@ func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateAPIKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key.GroupID = &group.ID
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
@@ -320,13 +320,13 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchApiKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := s.mustCreateAPIKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2.GroupID = &group.ID
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
@@ -346,7 +346,7 @@ func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
func (s *ApiKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
@@ -359,7 +359,7 @@ func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
func (s *ApiKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
@@ -370,10 +370,10 @@ func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateAPIKey(userID int64, key, name string, groupID *int64) *service.APIKey {
|
||||
func (s *ApiKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.ApiKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.APIKey{
|
||||
k := &service.ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package repository 提供应用程序的基础设施层组件。
|
||||
// Package infrastructure 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package repository
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.ApiKey) *service.ApiKey {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -257,7 +257,7 @@ func mustCreateAPIKey(t *testing.T, client *dbent.Client, k *service.APIKey) *se
|
||||
k.Name = "default"
|
||||
}
|
||||
|
||||
create := client.APIKey.Create().
|
||||
create := client.ApiKey.Create().
|
||||
SetUserID(k.UserID).
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
|
||||
@@ -293,8 +293,8 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.APIKey.Update().
|
||||
// 与 ApiKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.ApiKey.Update().
|
||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
|
||||
@@ -34,15 +34,15 @@ func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, emai
|
||||
return u
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用全局 ent client,确保软删除验证在实际持久化数据上进行。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
|
||||
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
|
||||
Name: "soft-delete",
|
||||
@@ -53,28 +53,28 @@ func TestEntSoftDelete_APIKey_DefaultFilterAndSkip(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
_, err := repo.GetByID(ctx, key.ID)
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
|
||||
require.ErrorIs(t, err, service.ErrApiKeyNotFound, "deleted rows should be hidden by default")
|
||||
|
||||
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
_, err = client.ApiKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
|
||||
require.Error(t, err, "default ent query should not see soft-deleted rows")
|
||||
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
|
||||
|
||||
got, err := client.APIKey.Query().
|
||||
got, err := client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
|
||||
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
|
||||
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用全局 ent client,避免事务回滚影响幂等性验证。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
|
||||
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
|
||||
Name: "soft-delete2",
|
||||
@@ -86,15 +86,15 @@ func TestEntSoftDelete_APIKey_DeleteIdempotent(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
|
||||
}
|
||||
|
||||
func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// 使用全局 ent client,确保 SkipSoftDelete 的硬删除语义可验证。
|
||||
client := testEntClient(t)
|
||||
|
||||
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
|
||||
|
||||
repo := NewAPIKeyRepository(client)
|
||||
key := &service.APIKey{
|
||||
repo := NewApiKeyRepository(client)
|
||||
key := &service.ApiKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
|
||||
Name: "soft-delete3",
|
||||
@@ -105,10 +105,10 @@ func TestEntSoftDelete_APIKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
|
||||
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
|
||||
|
||||
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
|
||||
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
_, err := client.ApiKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
|
||||
require.NoError(t, err, "hard delete")
|
||||
|
||||
_, err = client.APIKey.Query().
|
||||
_, err = client.ApiKey.Query().
|
||||
Where(apikey.IDEQ(key.ID)).
|
||||
Only(mixins.SkipSoftDelete(ctx))
|
||||
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -17,14 +18,15 @@ import (
|
||||
|
||||
type userRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
|
||||
return newUserRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, _ sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client}
|
||||
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
|
||||
return &userRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
|
||||
@@ -194,7 +196,11 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
// If attribute filters are specified, we need to filter by user IDs first
|
||||
var allowedUserIDs []int64
|
||||
if len(filters.Attributes) > 0 {
|
||||
allowedUserIDs = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
var attrErr error
|
||||
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
|
||||
if attrErr != nil {
|
||||
return nil, nil, attrErr
|
||||
}
|
||||
if len(allowedUserIDs) == 0 {
|
||||
// No users match the attribute filters
|
||||
return []service.User{}, paginationResultFromTotal(0, params), nil
|
||||
@@ -262,56 +268,53 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) []int64 {
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
|
||||
if len(attrs) == 0 {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// For each attribute filter, get the set of matching user IDs
|
||||
// Then intersect all sets to get users matching ALL filters
|
||||
var resultSet map[int64]struct{}
|
||||
first := true
|
||||
if r.sql == nil {
|
||||
return nil, fmt.Errorf("sql executor is not configured")
|
||||
}
|
||||
|
||||
clauses := make([]string, 0, len(attrs))
|
||||
args := make([]any, 0, len(attrs)*2+1)
|
||||
argIndex := 1
|
||||
for attrID, value := range attrs {
|
||||
// Query user_attribute_values for this attribute
|
||||
values, err := r.client.UserAttributeValue.Query().
|
||||
Where(
|
||||
userattributevalue.AttributeIDEQ(attrID),
|
||||
userattributevalue.ValueContainsFold(value),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
currentSet := make(map[int64]struct{}, len(values))
|
||||
for _, v := range values {
|
||||
currentSet[v.UserID] = struct{}{}
|
||||
}
|
||||
|
||||
if first {
|
||||
resultSet = currentSet
|
||||
first = false
|
||||
} else {
|
||||
// Intersect with previous results
|
||||
for userID := range resultSet {
|
||||
if _, ok := currentSet[userID]; !ok {
|
||||
delete(resultSet, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Early exit if no users match
|
||||
if len(resultSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
|
||||
args = append(args, attrID, "%"+value+"%")
|
||||
argIndex += 2
|
||||
}
|
||||
|
||||
result := make([]int64, 0, len(resultSet))
|
||||
for userID := range resultSet {
|
||||
query := fmt.Sprintf(
|
||||
`SELECT user_id
|
||||
FROM user_attribute_values
|
||||
WHERE %s
|
||||
GROUP BY user_id
|
||||
HAVING COUNT(DISTINCT attribute_id) = $%d`,
|
||||
strings.Join(clauses, " OR "),
|
||||
argIndex,
|
||||
)
|
||||
args = append(args, len(attrs))
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make([]int64, 0)
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
if scanErr := rows.Scan(&userID); scanErr != nil {
|
||||
return nil, scanErr
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
|
||||
@@ -28,13 +28,12 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
NewAPIKeyRepository,
|
||||
NewApiKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewOpsRepository,
|
||||
NewSettingRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
@@ -43,7 +42,8 @@ var ProviderSet = wire.NewSet(
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
NewBillingCache,
|
||||
NewAPIKeyCache,
|
||||
NewApiKeyCache,
|
||||
NewTempUnschedCache,
|
||||
ProvideConcurrencyCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package server provides HTTP server setup and routing configuration.
|
||||
package server
|
||||
|
||||
import (
|
||||
@@ -26,8 +25,8 @@ func ProvideRouter(
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
|
||||
@@ -32,7 +32,7 @@ func adminAuth(
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -52,48 +52,19 @@ func adminAuth(
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket 请求无法设置自定义 header,允许在 query 中携带凭证
|
||||
if isWebSocketRequest(c) {
|
||||
if token := strings.TrimSpace(c.Query("token")); token != "" {
|
||||
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if apiKey := strings.TrimSpace(c.Query("api_key")); apiKey != "" {
|
||||
if !validateAdminAPIKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
func isWebSocketRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
|
||||
return true
|
||||
}
|
||||
conn := strings.ToLower(c.GetHeader("Connection"))
|
||||
return strings.Contains(conn, "upgrade") && strings.EqualFold(c.GetHeader("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
// validateAdminAPIKey 验证管理员 API Key
|
||||
func validateAdminAPIKey(
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminAPIKey(c.Request.Context())
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
|
||||
@@ -11,13 +11,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewAPIKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) APIKeyAuthMiddleware {
|
||||
return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg, opsService))
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config, opsService *service.OpsService) gin.HandlerFunc {
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
@@ -53,7 +53,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
|
||||
// 如果所有header都没有API key
|
||||
if apiKeyString == "" {
|
||||
recordOpsAuthError(c, opsService, nil, 401, "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
return
|
||||
}
|
||||
@@ -61,40 +60,35 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
recordOpsAuthError(c, opsService, nil, 401, "Invalid API key")
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
recordOpsAuthError(c, opsService, nil, 500, "Failed to validate API key")
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
recordOpsAuthError(c, opsService, apiKey, 401, "API key is disabled")
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 401, "User associated with API key not found")
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !apiKey.User.IsActive() {
|
||||
recordOpsAuthError(c, opsService, apiKey, 401, "User account is not active")
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -115,14 +109,12 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 403, "No active subscription found for this group")
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 403, err.Error())
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
@@ -139,7 +131,6 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
recordOpsAuthError(c, opsService, apiKey, 429, err.Error())
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
return
|
||||
}
|
||||
@@ -149,14 +140,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
recordOpsAuthError(c, opsService, apiKey, 403, "Insufficient account balance")
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -167,66 +157,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
}
|
||||
}
|
||||
|
||||
func recordOpsAuthError(c *gin.Context, opsService *service.OpsService, apiKey *service.APIKey, status int, message string) {
|
||||
if opsService == nil || c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
errType := "authentication_error"
|
||||
phase := "auth"
|
||||
severity := "P3"
|
||||
switch status {
|
||||
case 403:
|
||||
errType = "billing_error"
|
||||
phase = "billing"
|
||||
case 429:
|
||||
errType = "rate_limit_error"
|
||||
phase = "billing"
|
||||
severity = "P2"
|
||||
case 500:
|
||||
errType = "api_error"
|
||||
phase = "internal"
|
||||
severity = "P1"
|
||||
}
|
||||
|
||||
logEntry := &service.OpsErrorLog{
|
||||
Phase: phase,
|
||||
Type: errType,
|
||||
Severity: severity,
|
||||
StatusCode: status,
|
||||
Message: message,
|
||||
ClientIP: c.ClientIP(),
|
||||
RequestPath: func() string {
|
||||
if c.Request != nil && c.Request.URL != nil {
|
||||
return c.Request.URL.Path
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
}
|
||||
|
||||
if apiKey != nil {
|
||||
logEntry.APIKeyID = &apiKey.ID
|
||||
if apiKey.User != nil {
|
||||
logEntry.UserID = &apiKey.User.ID
|
||||
}
|
||||
if apiKey.GroupID != nil {
|
||||
logEntry.GroupID = apiKey.GroupID
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
logEntry.Platform = apiKey.Group.Platform
|
||||
}
|
||||
}
|
||||
|
||||
enqueueOpsAuthErrorLog(opsService, logEntry)
|
||||
}
|
||||
|
||||
// GetAPIKeyFromContext 从上下文中获取API key
|
||||
func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyAPIKey))
|
||||
// GetApiKeyFromContext 从上下文中获取API key
|
||||
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyApiKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*service.APIKey)
|
||||
apiKey, ok := value.(*service.ApiKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
|
||||
@@ -11,16 +11,16 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// APIKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func APIKeyAuthGoogle(apiKeyService *service.APIKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// APIKeyAuthWithSubscriptionGoogle behaves like APIKeyAuthWithSubscription but returns Google-style errors:
|
||||
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
@@ -30,7 +30,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
abortWithGoogleError(c, 401, "Invalid API key")
|
||||
return
|
||||
}
|
||||
@@ -53,7 +53,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
@@ -92,7 +92,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
|
||||
@@ -16,53 +16,53 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
type fakeApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -74,8 +74,8 @@ type googleErrorResponse struct {
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService {
|
||||
return service.NewAPIKeyService(
|
||||
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
|
||||
return service.NewApiKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
@@ -85,16 +85,16 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
|
||||
)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -109,16 +109,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -134,16 +134,16 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -159,13 +159,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
require.Equal(t, "INTERNAL", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
@@ -176,7 +176,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
@@ -192,13 +192,13 @@ func TestAPIKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
return &service.APIKey{
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
@@ -210,7 +210,7 @@ func TestAPIKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
|
||||
@@ -35,7 +35,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
apiKey := &service.ApiKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
@@ -45,10 +45,10 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
@@ -57,7 +57,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
@@ -110,75 +110,75 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg, nil)))
|
||||
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
type stubAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,14 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var sensitiveQueryParamRE = regexp.MustCompile(`(?i)([?&](?:token|api_key)=)[^&#]*`)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
@@ -29,7 +26,7 @@ func Logger() gin.HandlerFunc {
|
||||
method := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
path := sensitiveQueryParamRE.ReplaceAllString(c.Request.URL.RequestURI(), "${1}***")
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// Package middleware provides HTTP middleware components for authentication,
|
||||
// authorization, logging, error recovery, and request processing.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
@@ -17,8 +15,8 @@ const (
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyUserRole 当前用户角色(string)
|
||||
ContextKeyUserRole ContextKey = "user_role"
|
||||
// ContextKeyAPIKey API密钥上下文键
|
||||
ContextKeyAPIKey ContextKey = "api_key"
|
||||
// ContextKeyApiKey API密钥上下文键
|
||||
ContextKeyApiKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||
|
||||
@@ -11,12 +11,12 @@ type JWTAuthMiddleware gin.HandlerFunc
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// APIKeyAuthMiddleware API Key 认证中间件类型
|
||||
type APIKeyAuthMiddleware gin.HandlerFunc
|
||||
// ApiKeyAuthMiddleware API Key 认证中间件类型
|
||||
type ApiKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewAPIKeyAuthMiddleware,
|
||||
NewApiKeyAuthMiddleware,
|
||||
)
|
||||
|
||||
@@ -17,8 +17,8 @@ func SetupRouter(
|
||||
handlers *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) *gin.Engine {
|
||||
@@ -43,8 +43,8 @@ func registerRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware2.JWTAuthMiddleware,
|
||||
adminAuth middleware2.AdminAuthMiddleware,
|
||||
apiKeyAuth middleware2.APIKeyAuthMiddleware,
|
||||
apiKeyService *service.APIKeyService,
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package routes 提供 HTTP 路由注册和处理函数
|
||||
package routes
|
||||
|
||||
import (
|
||||
|
||||
@@ -50,7 +50,7 @@ func RegisterUserRoutes(
|
||||
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
|
||||
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
|
||||
usage.GET("/dashboard/models", h.Usage.DashboardModels)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardAPIKeysUsage)
|
||||
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
|
||||
}
|
||||
|
||||
// 卡密兑换
|
||||
|
||||
@@ -29,6 +29,9 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time
|
||||
OverloadUntil *time.Time
|
||||
|
||||
TempUnschedulableUntil *time.Time
|
||||
TempUnschedulableReason string
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
@@ -39,6 +42,13 @@ type Account struct {
|
||||
Groups []*Group
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
Keywords []string `json:"keywords"`
|
||||
DurationMinutes int `json:"duration_minutes"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
@@ -54,6 +64,9 @@ func (a *Account) IsSchedulable() bool {
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -163,6 +176,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_enabled"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := raw.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_rules"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules := make([]TempUnschedulableRule, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
entry, ok := item.(map[string]any)
|
||||
if !ok || entry == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rule := TempUnschedulableRule{
|
||||
ErrorCode: parseTempUnschedInt(entry["error_code"]),
|
||||
Keywords: parseTempUnschedStrings(entry["keywords"]),
|
||||
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
|
||||
Description: parseTempUnschedString(entry["description"]),
|
||||
}
|
||||
|
||||
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func parseTempUnschedString(value any) string {
|
||||
s, ok := value.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func parseTempUnschedStrings(value any) []string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw []string
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
raw = v
|
||||
case []any:
|
||||
raw = make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
raw = append(raw, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
s := strings.TrimSpace(item)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseTempUnschedInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -206,7 +327,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeAPIKey {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
@@ -229,7 +350,7 @@ func (a *Account) GetExtraString(key string) string {
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
@@ -300,15 +421,15 @@ func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIAPIKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeAPIKey {
|
||||
if a.Type == AccountTypeApiKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
@@ -338,8 +459,8 @@ func (a *Account) GetOpenAIIDToken() string {
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIAPIKey() string {
|
||||
if !a.IsOpenAIAPIKey() {
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// Package service 提供业务逻辑层服务,封装领域模型的业务规则和操作流程。
|
||||
// 服务层协调 repository 层的数据访问,实现跨实体的业务逻辑,并为上层 API 提供统一的业务接口。
|
||||
package service
|
||||
|
||||
import (
|
||||
@@ -51,6 +49,8 @@ type AccountRepository interface {
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
|
||||
@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
panic("unexpected SetTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearRateLimit call")
|
||||
}
|
||||
|
||||
@@ -324,7 +324,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use Platform API
|
||||
authToken = account.GetOpenAIAPIKey()
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
@@ -402,7 +402,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@@ -426,7 +426,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
|
||||
@@ -2,7 +2,7 @@ package service
|
||||
|
||||
import "time"
|
||||
|
||||
type APIKey struct {
|
||||
type ApiKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
@@ -15,6 +15,6 @@ type APIKey struct {
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
@@ -14,39 +14,39 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
GetByKey(ctx context.Context, key string) (*ApiKey, error)
|
||||
Update(ctx context.Context, key *ApiKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
@@ -55,40 +55,40 @@ type APIKeyCache interface {
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
// UpdateApiKeyRequest 更新API Key请求
|
||||
type UpdateApiKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo ApiKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache APIKeyCache
|
||||
cache ApiKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
func NewAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache APIKeyCache,
|
||||
cache ApiKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
return &APIKeyService{
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -99,7 +99,7 @@ func NewAPIKeyService(
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -107,7 +107,7 @@ func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.APIKeyPrefix
|
||||
prefix := s.cfg.Default.ApiKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
@@ -117,10 +117,10 @@ func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrAPIKeyTooShort
|
||||
return ErrApiKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
@@ -131,14 +131,14 @@ func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrAPIKeyInvalidChars
|
||||
return ErrApiKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -150,14 +150,14 @@ func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrAPIKeyRateLimited
|
||||
return ErrApiKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID in
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
@@ -179,7 +179,7 @@ func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -204,7 +204,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||||
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -219,9 +219,9 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
return nil, fmt.Errorf("check key exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||||
return nil, ErrAPIKeyExists
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementApiKeyErrorCount(ctx, userID)
|
||||
return nil, ErrApiKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
@@ -235,7 +235,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &APIKey{
|
||||
apiKey := &ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
@@ -251,7 +251,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -259,7 +259,7 @@ func (s *APIKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -281,7 +281,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
@@ -301,7 +301,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -353,8 +353,8 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
|
||||
// Delete 删除API Key
|
||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -379,7 +379,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
@@ -406,7 +406,7 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
@@ -423,7 +423,7 @@ func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -460,7 +460,7 @@ func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
@@ -469,8 +469,8 @@ func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build unit
|
||||
|
||||
// API Key 服务删除方法的单元测试
|
||||
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
|
||||
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
|
||||
// 包括权限验证、缓存清理和错误处理
|
||||
|
||||
package service
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound)
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
|
||||
return s.ownerID, s.ownerErr
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
panic("unexpected SearchApiKeys call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
@@ -132,17 +132,17 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAPIKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回所有者 ID 为 1
|
||||
// - 调用者 userID 为 2(不匹配)
|
||||
// - 返回 ErrInsufficientPerms 错误
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 1}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
@@ -150,17 +150,17 @@ func TestAPIKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||
}
|
||||
|
||||
// TestAPIKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回所有者 ID 为 7
|
||||
// - 调用者 userID 为 7(匹配)
|
||||
// - Delete 成功执行
|
||||
// - 缓存被正确清除(使用 ownerID)
|
||||
// - 返回 nil 错误
|
||||
func TestAPIKeyService_Delete_Success(t *testing.T) {
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 7}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
@@ -168,37 +168,37 @@ func TestAPIKeyService_Delete_Success(t *testing.T) {
|
||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||
}
|
||||
|
||||
// TestAPIKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
|
||||
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestAPIKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 99, 1)
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.ErrorIs(t, err, ErrApiKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
}
|
||||
|
||||
// TestAPIKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回正确的所有者 ID
|
||||
// - 所有权验证通过
|
||||
// - 缓存被清除(在删除之前)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete api key" 的错误信息
|
||||
func TestAPIKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
ownerID: 3,
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
|
||||
@@ -82,7 +82,7 @@ type crsExportResponse struct {
|
||||
OpenAIOAuthAccounts []crsOpenAIOAuthAccount `json:"openaiOAuthAccounts"`
|
||||
OpenAIResponsesAccounts []crsOpenAIResponsesAccount `json:"openaiResponsesAccounts"`
|
||||
GeminiOAuthAccounts []crsGeminiOAuthAccount `json:"geminiOAuthAccounts"`
|
||||
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiAPIKeyAccounts"`
|
||||
GeminiAPIKeyAccounts []crsGeminiAPIKeyAccount `json:"geminiApiKeyAccounts"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Type: AccountTypeApiKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformAnthropic
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Type: AccountTypeApiKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformOpenAI
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeAPIKey,
|
||||
Type: AccountTypeApiKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformGemini
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
|
||||
@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const (
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -64,13 +64,13 @@ const (
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySMTPPort = "smtp_port" // SMTP端口
|
||||
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySMTPFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
@@ -81,20 +81,27 @@ const (
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
// Gemini 配额策略(JSON)
|
||||
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||
|
||||
// Model fallback settings
|
||||
SettingKeyEnableModelFallback = "enable_model_fallback"
|
||||
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
|
||||
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
|
||||
SettingKeyFallbackModelGemini = "fallback_model_gemini"
|
||||
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys)
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
|
||||
@@ -40,8 +40,8 @@ const (
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// SMTPConfig SMTP配置
|
||||
type SMTPConfig struct {
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
||||
}
|
||||
}
|
||||
|
||||
// GetSMTPConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
// GetSmtpConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
keys := []string{
|
||||
SettingKeySMTPHost,
|
||||
SettingKeySMTPPort,
|
||||
SettingKeySMTPUsername,
|
||||
SettingKeySMTPPassword,
|
||||
SettingKeySMTPFrom,
|
||||
SettingKeySMTPFromName,
|
||||
SettingKeySMTPUseTLS,
|
||||
SettingKeySmtpHost,
|
||||
SettingKeySmtpPort,
|
||||
SettingKeySmtpUsername,
|
||||
SettingKeySmtpPassword,
|
||||
SettingKeySmtpFrom,
|
||||
SettingKeySmtpFromName,
|
||||
SettingKeySmtpUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -82,34 +82,34 @@ func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[SettingKeySMTPHost]
|
||||
host := settings[SettingKeySmtpHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
|
||||
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[SettingKeySMTPUseTLS] == "true"
|
||||
useTLS := settings[SettingKeySmtpUseTLS] == "true"
|
||||
|
||||
return &SMTPConfig{
|
||||
return &SmtpConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[SettingKeySMTPUsername],
|
||||
Password: settings[SettingKeySMTPPassword],
|
||||
From: settings[SettingKeySMTPFrom],
|
||||
FromName: settings[SettingKeySMTPFromName],
|
||||
Username: settings[SettingKeySmtpUsername],
|
||||
Password: settings[SettingKeySmtpPassword],
|
||||
From: settings[SettingKeySmtpFrom],
|
||||
FromName: settings[SettingKeySmtpFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件(使用数据库中保存的配置)
|
||||
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
|
||||
config, err := s.GetSMTPConfig(ctx)
|
||||
config, err := s.GetSmtpConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
|
||||
}
|
||||
|
||||
// SendEmailWithConfig 使用指定配置发送邮件
|
||||
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
||||
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
|
||||
from := config.From
|
||||
if config.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
|
||||
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
`, siteName, code)
|
||||
}
|
||||
|
||||
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
|
||||
@@ -487,8 +487,8 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetOpenAIAPIKey()
|
||||
case AccountTypeApiKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
}
|
||||
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
var errType, errMsg string
|
||||
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
APIKey *APIKey
|
||||
ApiKey *ApiKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.APIKey
|
||||
apiKey := input.ApiKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
|
||||
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeySiteName,
|
||||
SettingKeySiteLogo,
|
||||
SettingKeySiteSubtitle,
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyApiBaseUrl,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyDocUrl,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
|
||||
updates[SettingKeySMTPUsername] = settings.SMTPUsername
|
||||
if settings.SMTPPassword != "" {
|
||||
updates[SettingKeySMTPPassword] = settings.SMTPPassword
|
||||
updates[SettingKeySmtpHost] = settings.SmtpHost
|
||||
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
|
||||
updates[SettingKeySmtpUsername] = settings.SmtpUsername
|
||||
if settings.SmtpPassword != "" {
|
||||
updates[SettingKeySmtpPassword] = settings.SmtpPassword
|
||||
}
|
||||
updates[SettingKeySMTPFrom] = settings.SMTPFrom
|
||||
updates[SettingKeySMTPFromName] = settings.SMTPFromName
|
||||
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
|
||||
updates[SettingKeySmtpFrom] = settings.SmtpFrom
|
||||
updates[SettingKeySmtpFromName] = settings.SmtpFromName
|
||||
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
|
||||
|
||||
// Cloudflare Turnstile 设置(只有非空才更新密钥)
|
||||
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
|
||||
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
|
||||
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
|
||||
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
updates[SettingKeyDocUrl] = settings.DocUrl
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
|
||||
// Model fallback configuration
|
||||
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
|
||||
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
|
||||
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
|
||||
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
|
||||
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, updates)
|
||||
}
|
||||
|
||||
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
SettingKeySmtpPort: "587",
|
||||
SettingKeySmtpUseTLS: "false",
|
||||
// Model fallback defaults
|
||||
SettingKeyEnableModelFallback: "false",
|
||||
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
|
||||
SettingKeyFallbackModelOpenAI: "gpt-4o",
|
||||
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
|
||||
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
SMTPFromName: settings[SettingKeySMTPFromName],
|
||||
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
||||
SmtpHost: settings[SettingKeySmtpHost],
|
||||
SmtpUsername: settings[SettingKeySmtpUsername],
|
||||
SmtpFrom: settings[SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
|
||||
result.SMTPPort = port
|
||||
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
|
||||
result.SmtpPort = port
|
||||
} else {
|
||||
result.SMTPPort = 587
|
||||
result.SmtpPort = 587
|
||||
}
|
||||
|
||||
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
|
||||
@@ -245,10 +258,17 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[SettingKeySmtpPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
// Model fallback settings
|
||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
|
||||
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
|
||||
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
return value
|
||||
}
|
||||
|
||||
// GenerateAdminAPIKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
|
||||
// GenerateAdminApiKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
|
||||
// 生成 32 字节随机数 = 64 位十六进制字符
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
|
||||
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
|
||||
// 存储到 settings 表
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
|
||||
return "", fmt.Errorf("save admin api key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
|
||||
// GetAdminApiKeyStatus 获取管理员 API Key 状态
|
||||
// 返回脱敏的 key、是否存在、错误
|
||||
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
|
||||
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", false, nil
|
||||
@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey st
|
||||
return maskedKey, true, nil
|
||||
}
|
||||
|
||||
// GetAdminAPIKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
|
||||
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
|
||||
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", nil // 未配置,返回空字符串
|
||||
@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeleteAdminAPIKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
|
||||
}
|
||||
|
||||
// IsModelFallbackEnabled 检查是否启用模型兜底机制
|
||||
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
|
||||
if err != nil {
|
||||
return false // Default: disabled
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetFallbackModel 获取指定平台的兜底模型
|
||||
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
|
||||
var key string
|
||||
var defaultModel string
|
||||
|
||||
switch platform {
|
||||
case PlatformAnthropic:
|
||||
key = SettingKeyFallbackModelAnthropic
|
||||
defaultModel = "claude-3-5-sonnet-20241022"
|
||||
case PlatformOpenAI:
|
||||
key = SettingKeyFallbackModelOpenAI
|
||||
defaultModel = "gpt-4o"
|
||||
case PlatformGemini:
|
||||
key = SettingKeyFallbackModelGemini
|
||||
defaultModel = "gemini-2.5-pro"
|
||||
case PlatformAntigravity:
|
||||
key = SettingKeyFallbackModelAntigravity
|
||||
defaultModel = "gemini-2.5-pro"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
value, err := s.settingRepo.GetValue(ctx, key)
|
||||
if err != nil || value == "" {
|
||||
return defaultModel
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFrom string
|
||||
SMTPFromName string
|
||||
SMTPUseTLS bool
|
||||
SmtpHost string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpFrom string
|
||||
SmtpFromName string
|
||||
SmtpUseTLS bool
|
||||
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
@@ -19,12 +19,19 @@ type SystemSettings struct {
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ApiBaseUrl string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
DocUrl string
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
@@ -35,8 +42,8 @@ type PublicSettings struct {
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ApiBaseUrl string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
DocUrl string
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ type ReleaseInfo struct {
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
HtmlURL string `json:"html_url"`
|
||||
Assets []Asset `json:"assets,omitempty"`
|
||||
}
|
||||
|
||||
@@ -96,13 +96,13 @@ type GitHubRelease struct {
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
HtmlUrl string `json:"html_url"`
|
||||
Assets []GitHubAsset `json:"assets"`
|
||||
}
|
||||
|
||||
type GitHubAsset struct {
|
||||
Name string `json:"name"`
|
||||
BrowserDownloadURL string `json:"browser_download_url"`
|
||||
BrowserDownloadUrl string `json:"browser_download_url"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
||||
for i, a := range release.Assets {
|
||||
assets[i] = Asset{
|
||||
Name: a.Name,
|
||||
DownloadURL: a.BrowserDownloadURL,
|
||||
DownloadURL: a.BrowserDownloadUrl,
|
||||
Size: a.Size,
|
||||
}
|
||||
}
|
||||
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
||||
Name: release.Name,
|
||||
Body: release.Body,
|
||||
PublishedAt: release.PublishedAt,
|
||||
HTMLURL: release.HTMLURL,
|
||||
HtmlURL: release.HtmlUrl,
|
||||
Assets: assets,
|
||||
},
|
||||
Cached: false,
|
||||
|
||||
35
backend/internal/service/usage_clamp.go
Normal file
35
backend/internal/service/usage_clamp.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// clampInt 将整数限制在指定范围内
|
||||
func clampInt(value, min, max int) int {
|
||||
if value < min {
|
||||
return min
|
||||
}
|
||||
if value > max {
|
||||
return max
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// clampFloat64 将浮点数限制在指定范围内
|
||||
func clampFloat64(value, min, max float64) float64 {
|
||||
if value < min {
|
||||
return min
|
||||
}
|
||||
if value > max {
|
||||
return max
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// remainingSecondsUntil 计算到指定时间的剩余秒数,保证非负
|
||||
func remainingSecondsUntil(t time.Time) int {
|
||||
seconds := int(time.Until(t).Seconds())
|
||||
if seconds < 0 {
|
||||
return 0
|
||||
}
|
||||
return seconds
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
type UsageLog struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
APIKeyID int64
|
||||
ApiKeyID int64
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
@@ -42,7 +42,7 @@ type UsageLog struct {
|
||||
CreatedAt time.Time
|
||||
|
||||
User *User
|
||||
APIKey *APIKey
|
||||
ApiKey *ApiKey
|
||||
Account *Account
|
||||
Group *Group
|
||||
Subscription *UserSubscription
|
||||
|
||||
@@ -17,7 +17,7 @@ var (
|
||||
// CreateUsageLogRequest 创建使用日志请求
|
||||
type CreateUsageLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
@@ -75,7 +75,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
||||
// 创建使用日志
|
||||
usageLog := &UsageLog{
|
||||
UserID: req.UserID,
|
||||
APIKeyID: req.APIKeyID,
|
||||
ApiKeyID: req.ApiKeyID,
|
||||
AccountID: req.AccountID,
|
||||
RequestID: req.RequestID,
|
||||
Model: req.Model,
|
||||
@@ -128,9 +128,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// ListByAPIKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
@@ -165,9 +165,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByAPIKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||
// GetStatsByApiKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key stats: %w", err)
|
||||
}
|
||||
@@ -270,9 +270,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type User struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
APIKeys []APIKey
|
||||
ApiKeys []ApiKey
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
|
||||
@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
|
||||
Enabled: input.Enabled,
|
||||
}
|
||||
|
||||
if err := validateDefinitionPattern(def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.defRepo.Create(ctx, def); err != nil {
|
||||
return nil, fmt.Errorf("create definition: %w", err)
|
||||
}
|
||||
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
|
||||
def.Enabled = *input.Enabled
|
||||
}
|
||||
|
||||
if err := validateDefinitionPattern(def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.defRepo.Update(ctx, def); err != nil {
|
||||
return nil, fmt.Errorf("update definition: %w", err)
|
||||
}
|
||||
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
|
||||
// Pattern validation
|
||||
if v.Pattern != nil && *v.Pattern != "" && value != "" {
|
||||
re, err := regexp.Compile(*v.Pattern)
|
||||
if err == nil && !re.MatchString(value) {
|
||||
if err != nil {
|
||||
return validationError(def.Name + " has an invalid pattern")
|
||||
}
|
||||
if !re.MatchString(value) {
|
||||
msg := def.Name + " format is invalid"
|
||||
if v.Message != nil && *v.Message != "" {
|
||||
msg = *v.Message
|
||||
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateDefinitionPattern(def *UserAttributeDefinition) error {
|
||||
if def == nil {
|
||||
return nil
|
||||
}
|
||||
if def.Validation.Pattern == nil {
|
||||
return nil
|
||||
}
|
||||
pattern := strings.TrimSpace(*def.Validation.Pattern)
|
||||
if pattern == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", fmt.Sprintf("invalid pattern for %s: %v", def.Name, err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,18 +54,6 @@ func ProvideTimingWheelService() *TimingWheelService {
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideAntigravityQuotaRefresher creates and starts AntigravityQuotaRefresher
|
||||
func ProvideAntigravityQuotaRefresher(
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
oauthSvc *AntigravityOAuthService,
|
||||
cfg *config.Config,
|
||||
) *AntigravityQuotaRefresher {
|
||||
svc := NewAntigravityQuotaRefresher(accountRepo, proxyRepo, oauthSvc, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideDeferredService creates and starts DeferredService
|
||||
func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService) *DeferredService {
|
||||
svc := NewDeferredService(accountRepo, timingWheel, 10*time.Second)
|
||||
@@ -73,20 +61,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
|
||||
func ProvideOpsMetricsCollector(opsService *OpsService, concurrencyService *ConcurrencyService) *OpsMetricsCollector {
|
||||
svc := NewOpsMetricsCollector(opsService, concurrencyService)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideOpsAlertService creates and starts OpsAlertService.
|
||||
func ProvideOpsAlertService(opsService *OpsService, userService *UserService, emailService *EmailService) *OpsAlertService {
|
||||
svc := NewOpsAlertService(opsService, userService, emailService)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
||||
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
||||
svc := NewConcurrencyService(cache)
|
||||
@@ -101,14 +75,13 @@ var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
NewAuthService,
|
||||
NewUserService,
|
||||
NewAPIKeyService,
|
||||
NewApiKeyService,
|
||||
NewGroupService,
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
NewRedeemService,
|
||||
NewUsageService,
|
||||
NewDashboardService,
|
||||
NewOpsService,
|
||||
ProvidePricingService,
|
||||
NewBillingService,
|
||||
NewBillingCacheService,
|
||||
@@ -139,8 +112,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideTokenRefreshService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDeferredService,
|
||||
ProvideAntigravityQuotaRefresher,
|
||||
ProvideOpsMetricsCollector,
|
||||
ProvideOpsAlertService,
|
||||
NewAntigravityQuotaFetcher,
|
||||
NewUserAttributeService,
|
||||
NewUsageCache,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package setup provides CLI-based installation wizard for initial system configuration.
|
||||
package setup
|
||||
|
||||
import (
|
||||
|
||||
@@ -345,7 +345,7 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
Default struct {
|
||||
UserConcurrency int `yaml:"user_concurrency"`
|
||||
UserBalance float64 `yaml:"user_balance"`
|
||||
APIKeyPrefix string `yaml:"api_key_prefix"`
|
||||
ApiKeyPrefix string `yaml:"api_key_prefix"`
|
||||
RateMultiplier float64 `yaml:"rate_multiplier"`
|
||||
} `yaml:"default"`
|
||||
RateLimit struct {
|
||||
@@ -367,12 +367,12 @@ func writeConfigFile(cfg *SetupConfig) error {
|
||||
Default: struct {
|
||||
UserConcurrency int `yaml:"user_concurrency"`
|
||||
UserBalance float64 `yaml:"user_balance"`
|
||||
APIKeyPrefix string `yaml:"api_key_prefix"`
|
||||
ApiKeyPrefix string `yaml:"api_key_prefix"`
|
||||
RateMultiplier float64 `yaml:"rate_multiplier"`
|
||||
}{
|
||||
UserConcurrency: 5,
|
||||
UserBalance: 0,
|
||||
APIKeyPrefix: "sk-",
|
||||
ApiKeyPrefix: "sk-",
|
||||
RateMultiplier: 1.0,
|
||||
},
|
||||
RateLimit: struct {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
//go:build !embed
|
||||
|
||||
// Package web provides web server functionality including embedded frontend support.
|
||||
package web
|
||||
|
||||
import (
|
||||
|
||||
Reference in New Issue
Block a user