From c6b3de11995ebae41a5faeea42f6704f29545ddb Mon Sep 17 00:00:00 2001 From: NepetaLemon Date: Sat, 20 Dec 2025 15:29:52 +0800 Subject: [PATCH] =?UTF-8?q?ci(backend):=20=E6=B7=BB=E5=8A=A0=20github=20ac?= =?UTF-8?q?tions=20(#10)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## 变更内容 ### CI/CD - 添加 GitHub Actions 工作流(test + golangci-lint) - 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard - 通过 depguard 强制 service 层不能直接导入 repository ### 错误处理修复 - 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误 - GenerateRedeemCode() 现在返回 error ### 资源泄露修复 - 统一使用 defer func() { _ = xxx.Close() }() 模式 ### 代码清理 - 移除未使用的常量 - 简化 nil map 检查 - 统一代码格式 --- backend/.github/workflows/ci.yml | 36 ++++++++++ backend/.golangci.yml | 27 ++++++++ backend/internal/config/config.go | 10 +-- .../internal/handler/admin/account_handler.go | 2 +- .../internal/handler/admin/proxy_handler.go | 1 - .../internal/handler/admin/redeem_handler.go | 24 +++++-- backend/internal/handler/gateway_handler.go | 18 +++-- backend/internal/model/account.go | 8 +-- backend/internal/model/group.go | 14 ++-- backend/internal/model/redeem_code.go | 14 ++-- backend/internal/model/setting.go | 16 ++--- backend/internal/model/usage_log.go | 2 +- backend/internal/model/user.go | 4 +- .../internal/pkg/timezone/timezone_test.go | 20 ++++-- .../repository/claude_usage_service.go | 2 +- .../repository/github_release_service.go | 10 +-- .../internal/repository/pricing_service.go | 4 +- .../repository/proxy_probe_service.go | 2 +- .../internal/repository/turnstile_service.go | 2 +- .../internal/service/account_test_service.go | 33 +++++++--- backend/internal/service/admin_service.go | 43 ++++++++---- backend/internal/service/api_key_service.go | 9 ++- .../internal/service/concurrency_service.go | 12 +--- backend/internal/service/email_service.go | 10 +-- backend/internal/service/gateway_service.go | 18 +++-- backend/internal/service/identity_service.go | 1 - backend/internal/service/redeem_service.go | 4 +- .../internal/service/subscription_service.go | 13 ++-- backend/internal/service/turnstile_service.go | 2 - backend/internal/service/update_service.go | 31 +++++---- backend/internal/setup/handler.go | 1 - backend/internal/setup/setup.go | 66 +++++++++++++++---- backend/internal/web/embed.go | 4 +- 33 files changed, 316 insertions(+), 147 deletions(-) create mode 100644 backend/.github/workflows/ci.yml create mode 100644 backend/.golangci.yml diff --git a/backend/.github/workflows/ci.yml b/backend/.github/workflows/ci.yml new file mode 100644 index 00000000..7efeea15 --- /dev/null +++ b/backend/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: CI + +on: + push: + pull_request: + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + check-latest: true + cache: true + - name: Run tests + run: go test ./... + + golangci-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + check-latest: true + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: latest + args: --timeout=5m diff --git a/backend/.golangci.yml b/backend/.golangci.yml new file mode 100644 index 00000000..45ecc0f3 --- /dev/null +++ b/backend/.golangci.yml @@ -0,0 +1,27 @@ +version: "2" + +linters: + default: none + enable: + - depguard + - errcheck + - govet + - ineffassign + - staticcheck + - unused + + settings: + depguard: + rules: + # Enforce: service must not depend on repository. + service-no-repository: + list-mode: original + files: + - internal/service/** + deny: + - pkg: sub2api/internal/repository + desc: "service must not import repository" + +formatters: + enable: + - gofmt diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 34ecbfb5..18fb162d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -52,7 +52,7 @@ type PricingConfig struct { type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` - Mode string `mapstructure:"mode"` // debug/release + Mode string `mapstructure:"mode"` // debug/release ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) } @@ -163,7 +163,7 @@ func setDefaults() { viper.SetDefault("server.port", 8080) viper.SetDefault("server.mode", "debug") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 - viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 + viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 // Database viper.SetDefault("database.host", "localhost") @@ -210,10 +210,10 @@ func setDefaults() { // TokenRefresh viper.SetDefault("token_refresh.enabled", true) - viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 + viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新 - viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 - viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 + viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 } func (c *Config) Validate() error { diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 5b3ac4f6..eef19b43 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -573,7 +573,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { // For API Key accounts: return models based on model_mapping mapping := account.GetModelMapping() - if mapping == nil || len(mapping) == 0 { + if len(mapping) == 0 { // No mapping configured, return default models response.Success(c, claude.DefaultModels) return diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index ee5b706d..8ffa5470 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { response.Paginated(c, accounts, total, page, pageSize) } - // BatchCreateProxyItem represents a single proxy in batch create request type BatchCreateProxyItem struct { Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 63856dca..6db998bd 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) { func (h *RedeemHandler) GetStats(c *gin.Context) { // Return mock data for now response.Success(c, gin.H{ - "total_codes": 0, - "active_codes": 0, - "used_codes": 0, - "expired_codes": 0, + "total_codes": 0, + "active_codes": 0, + "used_codes": 0, + "expired_codes": 0, "total_value_distributed": 0.0, "by_type": gin.H{ "balance": 0, @@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) { writer := csv.NewWriter(&buf) // Write header - writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}) + if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil { + response.InternalError(c, "Failed to export redeem codes: "+err.Error()) + return + } // Write data rows for _, code := range codes { @@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { if code.UsedAt != nil { usedAt = code.UsedAt.Format("2006-01-02 15:04:05") } - writer.Write([]string{ + if err := writer.Write([]string{ fmt.Sprintf("%d", code.ID), code.Code, code.Type, @@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) { usedBy, usedAt, code.CreatedAt.Format("2006-01-02 15:04:05"), - }) + }); err != nil { + response.InternalError(c, "Failed to export redeem codes: "+err.Error()) + return + } } writer.Flush() + if err := writer.Error(); err != nil { + response.InternalError(c, "Failed to export redeem codes: "+err.Error()) + return + } c.Header("Content-Type", "text/csv") c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv") diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 41f91872..9aa4f53d 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -268,7 +268,9 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id c.Header("X-Accel-Buffering", "no") *streamStarted = true } - fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n") + if _, err := fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n"); err != nil { + return nil, err + } flusher.Flush() } @@ -414,7 +416,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e if ok { // Send error event in SSE format errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) - fmt.Fprint(c.Writer, errorEvent) + if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { + _ = c.Error(err) + } flusher.Flush() } return @@ -574,11 +578,11 @@ func sendMockWarmupStream(c *gin.Context, model string) { // sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截) func sendMockWarmupResponse(c *gin.Context, model string) { c.JSON(http.StatusOK, gin.H{ - "id": "msg_mock_warmup", - "type": "message", - "role": "assistant", - "model": model, - "content": []gin.H{{"type": "text", "text": "New Conversation"}}, + "id": "msg_mock_warmup", + "type": "message", + "role": "assistant", + "model": model, + "content": []gin.H{{"type": "text", "text": "New Conversation"}}, "stop_reason": "end_turn", "usage": gin.H{ "input_tokens": 10, diff --git a/backend/internal/model/account.go b/backend/internal/model/account.go index 3040cf8f..22d2e7c8 100644 --- a/backend/internal/model/account.go +++ b/backend/internal/model/account.go @@ -40,8 +40,8 @@ type Account struct { Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息 ProxyID *int64 `gorm:"index" json:"proxy_id"` Concurrency int `gorm:"default:3;not null" json:"concurrency"` - Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高 - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error + Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高 + Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error ErrorMessage string `gorm:"type:text" json:"error_message"` LastUsedAt *time.Time `gorm:"index" json:"last_used_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"` @@ -163,7 +163,7 @@ func (a *Account) GetModelMapping() map[string]string { // 如果没有设置模型映射,则支持所有模型 func (a *Account) IsModelSupported(requestedModel string) bool { mapping := a.GetModelMapping() - if mapping == nil || len(mapping) == 0 { + if len(mapping) == 0 { return true // 没有映射配置,支持所有模型 } _, exists := mapping[requestedModel] @@ -174,7 +174,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool { // 如果没有映射,返回原始模型名 func (a *Account) GetMappedModel(requestedModel string) string { mapping := a.GetModelMapping() - if mapping == nil || len(mapping) == 0 { + if len(mapping) == 0 { return requestedModel } if mappedModel, exists := mapping[requestedModel]; exists { diff --git a/backend/internal/model/group.go b/backend/internal/model/group.go index f02b2692..b1bbe527 100644 --- a/backend/internal/model/group.go +++ b/backend/internal/model/group.go @@ -13,13 +13,13 @@ const ( ) type Group struct { - ID int64 `gorm:"primaryKey" json:"id"` - Name string `gorm:"uniqueIndex;size:100;not null" json:"name"` - Description string `gorm:"type:text" json:"description"` - Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini - RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"` - IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"` - Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled + ID int64 `gorm:"primaryKey" json:"id"` + Name string `gorm:"uniqueIndex;size:100;not null" json:"name"` + Description string `gorm:"type:text" json:"description"` + Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini + RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"` + IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"` + Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled // 订阅功能字段 SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription diff --git a/backend/internal/model/redeem_code.go b/backend/internal/model/redeem_code.go index 725361c3..602fcbe8 100644 --- a/backend/internal/model/redeem_code.go +++ b/backend/internal/model/redeem_code.go @@ -9,15 +9,15 @@ import ( type RedeemCode struct { ID int64 `gorm:"primaryKey" json:"id"` Code string `gorm:"uniqueIndex;size:32;not null" json:"code"` - Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription - Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数 + Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription + Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数 Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used UsedBy *int64 `gorm:"index" json:"used_by"` UsedAt *time.Time `json:"used_at"` CreatedAt time.Time `gorm:"not null" json:"created_at"` // 订阅类型专用字段 - GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用) + GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用) ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用) // 关联 @@ -40,8 +40,10 @@ func (r *RedeemCode) CanUse() bool { } // GenerateRedeemCode 生成唯一的兑换码 -func GenerateRedeemCode() string { +func GenerateRedeemCode() (string, error) { b := make([]byte, 16) - rand.Read(b) - return hex.EncodeToString(b) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil } diff --git a/backend/internal/model/setting.go b/backend/internal/model/setting.go index 3bfa5068..ec95030a 100644 --- a/backend/internal/model/setting.go +++ b/backend/internal/model/setting.go @@ -19,17 +19,17 @@ func (Setting) TableName() string { // 设置Key常量 const ( // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 // 邮件服务设置 - SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 - SettingKeySmtpPort = "smtp_port" // SMTP端口 - SettingKeySmtpUsername = "smtp_username" // SMTP用户名 - SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) - SettingKeySmtpFrom = "smtp_from" // 发件人地址 + SettingKeySmtpHost = "smtp_host" // SMTP服务器地址 + SettingKeySmtpPort = "smtp_port" // SMTP端口 + SettingKeySmtpUsername = "smtp_username" // SMTP用户名 + SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储) + SettingKeySmtpFrom = "smtp_from" // 发件人地址 SettingKeySmtpFromName = "smtp_from_name" // 发件人名称 - SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS + SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS // Cloudflare Turnstile 设置 SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证 diff --git a/backend/internal/model/usage_log.go b/backend/internal/model/usage_log.go index fb23cf73..b9ca0a77 100644 --- a/backend/internal/model/usage_log.go +++ b/backend/internal/model/usage_log.go @@ -37,7 +37,7 @@ type UsageLog struct { OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"` CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"` CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"` - TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用 + TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用 ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用 RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率 diff --git a/backend/internal/model/user.go b/backend/internal/model/user.go index 96552474..61fc5215 100644 --- a/backend/internal/model/user.go +++ b/backend/internal/model/user.go @@ -9,8 +9,8 @@ import ( ) type User struct { - ID int64 `gorm:"primaryKey" json:"id"` - Email string `gorm:"uniqueIndex;size:255;not null" json:"email"` + ID int64 `gorm:"primaryKey" json:"id"` + Email string `gorm:"uniqueIndex;size:255;not null" json:"email"` PasswordHash string `gorm:"size:255;not null" json:"-"` Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"` diff --git a/backend/internal/pkg/timezone/timezone_test.go b/backend/internal/pkg/timezone/timezone_test.go index 3d21246e..ac9cdde6 100644 --- a/backend/internal/pkg/timezone/timezone_test.go +++ b/backend/internal/pkg/timezone/timezone_test.go @@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) { func TestTimeNowAffected(t *testing.T) { // Reset to UTC first - Init("UTC") + if err := Init("UTC"); err != nil { + t.Fatalf("Init failed with UTC: %v", err) + } utcNow := time.Now() // Switch to Shanghai (UTC+8) - Init("Asia/Shanghai") + if err := Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init failed with Asia/Shanghai: %v", err) + } shanghaiNow := time.Now() // The times should be the same instant, but different timezone representation @@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) { } func TestToday(t *testing.T) { - Init("Asia/Shanghai") + if err := Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init failed with Asia/Shanghai: %v", err) + } today := Today() now := Now() @@ -75,7 +81,9 @@ func TestToday(t *testing.T) { } func TestStartOfDay(t *testing.T) { - Init("Asia/Shanghai") + if err := Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init failed with Asia/Shanghai: %v", err) + } // Create a time at 15:30:45 testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location()) @@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) { // This test demonstrates why Truncate(24*time.Hour) can be problematic // and why StartOfDay is more reliable for timezone-aware code - Init("Asia/Shanghai") + if err := Init("Asia/Shanghai"); err != nil { + t.Fatalf("Init failed with Asia/Shanghai: %v", err) + } now := Now() diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 45d0aace..88b8cc36 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -43,7 +43,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU if err != nil { return nil, fmt.Errorf("request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 927c0ee5..980cc345 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode) @@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string if err != nil { return err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return fmt.Errorf("download returned %d", resp.StatusCode) @@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string if err != nil { return err } - defer out.Close() + defer func() { _ = out.Close() }() // SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong limited := io.LimitReader(resp.Body, maxSize+1) @@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string // Check if we hit the limit (downloaded more than maxSize) if written > maxSize { - os.Remove(dest) // Clean up partial file + _ = os.Remove(dest) // Clean up partial file (best-effort) return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize) } @@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP %d", resp.StatusCode) diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index e634531c..b3c8e95d 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP %d", resp.StatusCode) @@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st if err != nil { return "", err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("HTTP %d", resp.StatusCode) diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index c670f3ef..0d281fba 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s if err != nil { return nil, 0, fmt.Errorf("proxy connection failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() latencyMs := time.Since(startTime).Milliseconds() diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go index 19152dbd..bfb62cbb 100644 --- a/backend/internal/repository/turnstile_service.go +++ b/backend/internal/repository/turnstile_service.go @@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r if err != nil { return nil, fmt.Errorf("send request: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() var result service.TurnstileVerifyResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index e7b44e83..26ab33ff 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -51,16 +51,23 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA } // generateSessionString generates a Claude Code style session string -func generateSessionString() string { +func generateSessionString() (string, error) { bytes := make([]byte, 32) - rand.Read(bytes) + if _, err := rand.Read(bytes); err != nil { + return "", err + } hex64 := hex.EncodeToString(bytes) sessionUUID := uuid.New().String() - return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID) + return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil } // createTestPayload creates a Claude Code style test request payload -func createTestPayload(modelID string) map[string]interface{} { +func createTestPayload(modelID string) (map[string]interface{}, error) { + sessionID, err := generateSessionString() + if err != nil { + return nil, err + } + return map[string]interface{}{ "model": modelID, "messages": []map[string]interface{}{ @@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} { }, }, "metadata": map[string]string{ - "user_id": generateSessionString(), + "user_id": sessionID, }, "max_tokens": 1024, "temperature": 1, "stream": true, - } + }, nil } // TestAccountConnection tests an account's connection by sending a test request @@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int // For API Key accounts with model mapping, map the model if account.Type == "apikey" { mapping := account.GetModelMapping() - if mapping != nil && len(mapping) > 0 { + if len(mapping) > 0 { if mappedModel, exists := mapping[testModelID]; exists { testModelID = mappedModel } @@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int c.Writer.Flush() // Create Claude Code style payload (same for all account types) - payload := createTestPayload(testModelID) + payload, err := createTestPayload(testModelID) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create test payload") + } payloadBytes, _ := json.Marshal(payload) // Send test_start event @@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) @@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error // sendEvent sends a SSE event to the client func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { eventJSON, _ := json.Marshal(event) - fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON) + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { + log.Printf("failed to write SSE event: %v", err) + return + } c.Writer.Flush() } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 09e8cd20..ca501ee9 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "time" "sub2api/internal/model" @@ -309,7 +310,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateUserBalance(cacheCtx, id) + if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil { + log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err) + } }() } } @@ -317,8 +320,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda // Create adjustment records for balance/concurrency changes balanceDiff := user.Balance - oldBalance if balanceDiff != 0 { + code, err := model.GenerateRedeemCode() + if err != nil { + log.Printf("failed to generate adjustment redeem code: %v", err) + return user, nil + } adjustmentRecord := &model.RedeemCode{ - Code: model.GenerateRedeemCode(), + Code: code, Type: model.AdjustmentTypeAdminBalance, Value: balanceDiff, Status: model.StatusUsed, @@ -327,15 +335,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - // Log error but don't fail the update - // The user update has already succeeded + log.Printf("failed to create balance adjustment redeem code: %v", err) } } concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { + code, err := model.GenerateRedeemCode() + if err != nil { + log.Printf("failed to generate adjustment redeem code: %v", err) + return user, nil + } adjustmentRecord := &model.RedeemCode{ - Code: model.GenerateRedeemCode(), + Code: code, Type: model.AdjustmentTypeAdminConcurrency, Value: float64(concurrencyDiff), Status: model.StatusUsed, @@ -344,8 +356,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - // Log error but don't fail the update - // The user update has already succeeded + log.Printf("failed to create concurrency adjustment redeem code: %v", err) } } @@ -388,7 +399,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { + log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + } }() } @@ -579,7 +592,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() for _, userID := range affectedUserIDs { - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { + log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + } } }() } @@ -646,10 +661,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U if input.Type != "" { account.Type = input.Type } - if input.Credentials != nil && len(input.Credentials) > 0 { + if len(input.Credentials) > 0 { account.Credentials = model.JSONB(input.Credentials) } - if input.Extra != nil && len(input.Extra) > 0 { + if len(input.Extra) > 0 { account.Extra = model.JSONB(input.Extra) } if input.ProxyID != nil { @@ -831,8 +846,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener codes := make([]model.RedeemCode, 0, input.Count) for i := 0; i < input.Count; i++ { + codeValue, err := model.GenerateRedeemCode() + if err != nil { + return nil, err + } code := model.RedeemCode{ - Code: model.GenerateRedeemCode(), + Code: codeValue, Type: input.Type, Value: input.Value, Status: model.StatusUnused, diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index adbd0f46..1d047633 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error { // 检查字符:只允许字母、数字、下划线、连字符 for _, c := range key { - if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || - (c >= '0' && c <= '9') || c == '_' || c == '-') { - return ErrApiKeyInvalidChars + if (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c == '_' || c == '-' { + continue } + return ErrApiKeyInvalidChars } return nil diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 5854ea67..c54167da 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -9,12 +9,6 @@ import ( ) const ( - // Wait polling interval - waitPollInterval = 100 * time.Millisecond - - // Default max wait time - defaultMaxWait = 60 * time.Second - // Default extra wait slots beyond concurrency limit defaultExtraWaitSlots = 20 ) @@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService { // AcquireResult represents the result of acquiring a concurrency slot type AcquireResult struct { - Acquired bool + Acquired bool ReleaseFunc func() // Must be called when done (typically via defer) } @@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i if acquired { return &AcquireResult{ - Acquired: true, + Acquired: true, ReleaseFunc: func() { bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, if acquired { return &AcquireResult{ - Acquired: true, + Acquired: true, ReleaseFunc: func() { bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index e9637846..4093287c 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, if err != nil { return fmt.Errorf("tls dial: %w", err) } - defer conn.Close() + defer func() { _ = conn.Close() }() client, err := smtp.NewClient(conn, host) if err != nil { return fmt.Errorf("new smtp client: %w", err) } - defer client.Close() + defer func() { _ = client.Close() }() if err = client.Auth(auth); err != nil { return fmt.Errorf("smtp auth: %w", err) @@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error { if err != nil { return fmt.Errorf("tls connection failed: %w", err) } - defer conn.Close() + defer func() { _ = conn.Close() }() client, err := smtp.NewClient(conn, config.Host) if err != nil { return fmt.Errorf("smtp client creation failed: %w", err) } - defer client.Close() + defer func() { _ = client.Close() }() auth := smtp.PlainAuth("", config.Username, config.Password, config.Host) if err = client.Auth(auth); err != nil { @@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error { if err != nil { return fmt.Errorf("smtp connection failed: %w", err) } - defer client.Close() + defer func() { _ = client.Close() }() auth := smtp.PlainAuth("", config.Username, config.Password, config.Host) if err = client.Auth(auth); err != nil { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0e208ac3..57b48ee2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -281,7 +281,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // 同时检查模型支持 if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { // 续期粘性会话 - s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) + if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) + } return account, nil } } @@ -331,7 +333,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // 4. 建立粘性绑定 if sessionHash != "" { - s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL) + if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } } return selected, nil @@ -411,7 +415,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m if err != nil { return nil, fmt.Errorf("upstream request failed: %w", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() // 处理错误响应(包括401,由后台TokenRefreshService维护token有效性) if resp.StatusCode >= 400 { @@ -678,7 +682,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } // 转发行 - fmt.Fprintf(w, "%s\n", line) + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } flusher.Flush() // 解析usage数据 @@ -985,7 +991,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") return fmt.Errorf("upstream request failed: %w", err) } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() // 读取响应体 respBody, err := io.ReadAll(resp.Body) diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 4bc9da6f..5493b724 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -15,7 +15,6 @@ import ( "time" ) - // 预编译正则表达式(避免每次调用重新编译) var ( // 匹配 user_id 格式: user_{64位hex}_account__session_{uuid} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index de4c7249..4cd97ab4 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) + _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) }() } @@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index d348b5f3..5d3c4d15 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "time" "sub2api/internal/model" @@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } @@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } newNotes += input.Notes if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil { - // 备注更新失败不影响主流程 + log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err) } } @@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } @@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } @@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } @@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) + _ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID) }() } diff --git a/backend/internal/service/turnstile_service.go b/backend/internal/service/turnstile_service.go index 81b6e3a0..318c7d32 100644 --- a/backend/internal/service/turnstile_service.go +++ b/backend/internal/service/turnstile_service.go @@ -12,8 +12,6 @@ var ( ErrTurnstileNotConfigured = errors.New("turnstile not configured") ) -const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify" - // TurnstileVerifier 验证 Turnstile token 的接口 type TurnstileVerifier interface { VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error) diff --git a/backend/internal/service/update_service.go b/backend/internal/service/update_service.go index bc88575b..ca716bef 100644 --- a/backend/internal/service/update_service.go +++ b/backend/internal/service/update_service.go @@ -14,6 +14,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "time" @@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to create temp dir: %w", err) } - defer os.RemoveAll(tempDir) + defer func() { _ = os.RemoveAll(tempDir) }() // Download archive archivePath := filepath.Join(tempDir, filepath.Base(downloadURL)) @@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error { backupPath := exePath + ".backup" // Remove old backup if exists - os.Remove(backupPath) + _ = os.Remove(backupPath) // Step 1: Move current binary to backup if err := os.Rename(exePath, backupPath); err != nil { @@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR if err != nil { return err } - defer f.Close() + defer func() { _ = f.Close() }() h := sha256.New() if _, err := io.Copy(h, f); err != nil { @@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { if err != nil { return err } - defer f.Close() + defer func() { _ = f.Close() }() var reader io.Reader = f @@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { if err != nil { return err } - defer gzr.Close() + defer func() { _ = gzr.Close() }() reader = gzr } @@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { // Use LimitReader to prevent decompression bombs limited := io.LimitReader(tr, maxBinarySize) if _, err := io.Copy(out, limited); err != nil { - out.Close() + _ = out.Close() + return err + } + if err := out.Close(); err != nil { return err } - out.Close() return nil } } @@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error { if err != nil { return err } - defer out.Close() limited := io.LimitReader(reader, maxBinarySize) - _, err = io.Copy(out, limited) - return err + if _, err := io.Copy(out, limited); err != nil { + _ = out.Close() + return err + } + return out.Close() } func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) { @@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) { } data, _ := json.Marshal(cacheData) - s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second) + _ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second) } // compareVersions compares two semantic versions @@ -523,7 +528,9 @@ func parseVersion(v string) [3]int { parts := strings.Split(v, ".") result := [3]int{0, 0, 0} for i := 0; i < len(parts) && i < 3; i++ { - fmt.Sscanf(parts[i], "%d", &result[i]) + if parsed, err := strconv.Atoi(parts[i]); err == nil { + result[i] = parsed + } } return result } diff --git a/backend/internal/setup/handler.go b/backend/internal/setup/handler.go index ebd02ed5..a2ddfa93 100644 --- a/backend/internal/setup/handler.go +++ b/backend/internal/setup/handler.go @@ -352,4 +352,3 @@ func install(c *gin.Context) { "restart": true, }) } - diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index e8cb702e..5c1d937c 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -14,9 +14,9 @@ import ( "github.com/redis/go-redis/v9" "golang.org/x/crypto/bcrypt" + "gopkg.in/yaml.v3" "gorm.io/driver/postgres" "gorm.io/gorm" - "gopkg.in/yaml.v3" ) // Config paths @@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { if err != nil { return fmt.Errorf("failed to get db instance: %w", err) } - defer sqlDB.Close() + defer func() { + if sqlDB == nil { + return + } + if err := sqlDB.Close(); err != nil { + log.Printf("failed to close postgres connection: %v", err) + } + }() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { } // Now connect to the target database to verify - sqlDB.Close() + if err := sqlDB.Close(); err != nil { + log.Printf("failed to close postgres connection: %v", err) + } + sqlDB = nil targetDSN := fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", @@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error { if err != nil { return fmt.Errorf("failed to get target db instance: %w", err) } - defer targetSqlDB.Close() + defer func() { + if err := targetSqlDB.Close(); err != nil { + log.Printf("failed to close postgres connection: %v", err) + } + }() ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) defer cancel2() @@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error { Password: cfg.Password, DB: cfg.DB, }) - defer rdb.Close() + defer func() { + if err := rdb.Close(); err != nil { + log.Printf("failed to close redis client: %v", err) + } + }() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error { // Generate JWT secret if not provided if cfg.JWT.Secret == "" { - cfg.JWT.Secret = generateSecret(32) + secret, err := generateSecret(32) + if err != nil { + return fmt.Errorf("failed to generate jwt secret: %w", err) + } + cfg.JWT.Secret = secret } // Test connections @@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error { if err != nil { return err } - defer sqlDB.Close() + defer func() { + if err := sqlDB.Close(); err != nil { + log.Printf("failed to close postgres connection: %v", err) + } + }() // 使用 model 包的 AutoMigrate,确保模型定义统一 return model.AutoMigrate(db) @@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error { if err != nil { return err } - defer sqlDB.Close() + defer func() { + if err := sqlDB.Close(); err != nil { + log.Printf("failed to close postgres connection: %v", err) + } + }() // Check if admin already exists var count int64 @@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error { return os.WriteFile(ConfigFile, data, 0600) } -func generateSecret(length int) string { +func generateSecret(length int) (string, error) { bytes := make([]byte, length) - rand.Read(bytes) - return hex.EncodeToString(bytes) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil } // ============================================================================= @@ -431,13 +463,21 @@ func AutoSetupFromEnv() error { // Generate JWT secret if not provided if cfg.JWT.Secret == "" { - cfg.JWT.Secret = generateSecret(32) + secret, err := generateSecret(32) + if err != nil { + return fmt.Errorf("failed to generate jwt secret: %w", err) + } + cfg.JWT.Secret = secret log.Println("Generated JWT secret automatically") } // Generate admin password if not provided if cfg.Admin.Password == "" { - cfg.Admin.Password = generateSecret(16) + password, err := generateSecret(16) + if err != nil { + return fmt.Errorf("failed to generate admin password: %w", err) + } + cfg.Admin.Password = password log.Printf("Generated admin password: %s", cfg.Admin.Password) log.Println("IMPORTANT: Save this password! It will not be shown again.") } diff --git a/backend/internal/web/embed.go b/backend/internal/web/embed.go index da5ba167..f05ae8df 100644 --- a/backend/internal/web/embed.go +++ b/backend/internal/web/embed.go @@ -41,7 +41,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc { } if file, err := distFS.Open(cleanPath); err == nil { - file.Close() + _ = file.Close() fileServer.ServeHTTP(c.Writer, c.Request) c.Abort() return @@ -59,7 +59,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) { c.Abort() return } - defer file.Close() + defer func() { _ = file.Close() }() content, err := io.ReadAll(file) if err != nil {