ci(backend): 添加 github actions (#10)
## 变更内容
### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository
### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error
### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式
### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
This commit is contained in:
@@ -51,16 +51,23 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() string {
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
hex64 := hex.EncodeToString(bytes)
|
||||
sessionUUID := uuid.New().String()
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||
}
|
||||
|
||||
// createTestPayload creates a Claude Code style test request payload
|
||||
func createTestPayload(modelID string) map[string]interface{} {
|
||||
func createTestPayload(modelID string) (map[string]interface{}, error) {
|
||||
sessionID, err := generateSessionString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"model": modelID,
|
||||
"messages": []map[string]interface{}{
|
||||
@@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} {
|
||||
},
|
||||
},
|
||||
"metadata": map[string]string{
|
||||
"user_id": generateSessionString(),
|
||||
"user_id": sessionID,
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"temperature": 1,
|
||||
"stream": true,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
@@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping != nil && len(mapping) > 0 {
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
@@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create Claude Code style payload (same for all account types)
|
||||
payload := createTestPayload(testModelID)
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
@@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
@@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
||||
// sendEvent sends a SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||
log.Printf("failed to write SSE event: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
@@ -309,7 +310,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -317,8 +320,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
// Create adjustment records for balance/concurrency changes
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
@@ -327,15 +335,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminConcurrency,
|
||||
Value: float64(concurrencyDiff),
|
||||
Status: model.StatusUsed,
|
||||
@@ -344,8 +356,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
now := time.Now()
|
||||
adjustmentRecord.UsedAt = &now
|
||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||
// Log error but don't fail the update
|
||||
// The user update has already succeeded
|
||||
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,7 +399,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -579,7 +592,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
for _, userID := range affectedUserIDs {
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -646,10 +661,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
if input.Type != "" {
|
||||
account.Type = input.Type
|
||||
}
|
||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = model.JSONB(input.Credentials)
|
||||
}
|
||||
if input.Extra != nil && len(input.Extra) > 0 {
|
||||
if len(input.Extra) > 0 {
|
||||
account.Extra = model.JSONB(input.Extra)
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
@@ -831,8 +846,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
||||
|
||||
codes := make([]model.RedeemCode, 0, input.Count)
|
||||
for i := 0; i < input.Count; i++ {
|
||||
codeValue, err := model.GenerateRedeemCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code := model.RedeemCode{
|
||||
Code: model.GenerateRedeemCode(),
|
||||
Code: codeValue,
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: model.StatusUnused,
|
||||
|
||||
@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
||||
return ErrApiKeyInvalidChars
|
||||
if (c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrApiKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -9,12 +9,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// Wait polling interval
|
||||
waitPollInterval = 100 * time.Millisecond
|
||||
|
||||
// Default max wait time
|
||||
defaultMaxWait = 60 * time.Second
|
||||
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
type AcquireResult struct {
|
||||
Acquired bool
|
||||
Acquired bool
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls dial: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new smtp client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp auth: %w", err)
|
||||
@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp connection failed: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
|
||||
@@ -281,7 +281,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
// 续期粘性会话
|
||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
@@ -331,7 +333,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
@@ -411,7 +415,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||
if resp.StatusCode >= 400 {
|
||||
@@ -678,7 +682,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
}
|
||||
|
||||
// 转发行
|
||||
fmt.Fprintf(w, "%s\n", line)
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 解析usage数据
|
||||
@@ -985,7 +991,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -12,8 +12,6 @@ var (
|
||||
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
||||
)
|
||||
|
||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
type TurnstileVerifier interface {
|
||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
// Download archive
|
||||
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
||||
@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
||||
backupPath := exePath + ".backup"
|
||||
|
||||
// Remove old backup if exists
|
||||
os.Remove(backupPath)
|
||||
_ = os.Remove(backupPath)
|
||||
|
||||
// Step 1: Move current binary to backup
|
||||
if err := os.Rename(exePath, backupPath); err != nil {
|
||||
@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
var reader io.Reader = f
|
||||
|
||||
@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gzr.Close()
|
||||
defer func() { _ = gzr.Close() }()
|
||||
reader = gzr
|
||||
}
|
||||
|
||||
@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
// Use LimitReader to prevent decompression bombs
|
||||
limited := io.LimitReader(tr, maxBinarySize)
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
out.Close()
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
if err := out.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
out.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
limited := io.LimitReader(reader, maxBinarySize)
|
||||
_, err = io.Copy(out, limited)
|
||||
return err
|
||||
if _, err := io.Copy(out, limited); err != nil {
|
||||
_ = out.Close()
|
||||
return err
|
||||
}
|
||||
return out.Close()
|
||||
}
|
||||
|
||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||
@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(cacheData)
|
||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
_ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||
}
|
||||
|
||||
// compareVersions compares two semantic versions
|
||||
@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int {
|
||||
parts := strings.Split(v, ".")
|
||||
result := [3]int{0, 0, 0}
|
||||
for i := 0; i < len(parts) && i < 3; i++ {
|
||||
fmt.Sscanf(parts[i], "%d", &result[i])
|
||||
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||
result[i] = parsed
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user