ci(backend): 添加 github actions (#10)
## 变更内容
### CI/CD
- 添加 GitHub Actions 工作流(test + golangci-lint)
- 添加 golangci-lint 配置,启用 errcheck/govet/staticcheck/unused/depguard
- 通过 depguard 强制 service 层不能直接导入 repository
### 错误处理修复
- 修复 CSV 写入、SSE 流式输出、随机数生成等未处理的错误
- GenerateRedeemCode() 现在返回 error
### 资源泄露修复
- 统一使用 defer func() { _ = xxx.Close() }() 模式
### 代码清理
- 移除未使用的常量
- 简化 nil map 检查
- 统一代码格式
This commit is contained in:
36
backend/.github/workflows/ci.yml
vendored
Normal file
36
backend/.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
|
- name: Run tests
|
||||||
|
run: go test ./...
|
||||||
|
|
||||||
|
golangci-lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: go.mod
|
||||||
|
check-latest: true
|
||||||
|
cache: true
|
||||||
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v6
|
||||||
|
with:
|
||||||
|
version: latest
|
||||||
|
args: --timeout=5m
|
||||||
27
backend/.golangci.yml
Normal file
27
backend/.golangci.yml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
linters:
|
||||||
|
default: none
|
||||||
|
enable:
|
||||||
|
- depguard
|
||||||
|
- errcheck
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- staticcheck
|
||||||
|
- unused
|
||||||
|
|
||||||
|
settings:
|
||||||
|
depguard:
|
||||||
|
rules:
|
||||||
|
# Enforce: service must not depend on repository.
|
||||||
|
service-no-repository:
|
||||||
|
list-mode: original
|
||||||
|
files:
|
||||||
|
- internal/service/**
|
||||||
|
deny:
|
||||||
|
- pkg: sub2api/internal/repository
|
||||||
|
desc: "service must not import repository"
|
||||||
|
|
||||||
|
formatters:
|
||||||
|
enable:
|
||||||
|
- gofmt
|
||||||
@@ -52,7 +52,7 @@ type PricingConfig struct {
|
|||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Host string `mapstructure:"host"`
|
Host string `mapstructure:"host"`
|
||||||
Port int `mapstructure:"port"`
|
Port int `mapstructure:"port"`
|
||||||
Mode string `mapstructure:"mode"` // debug/release
|
Mode string `mapstructure:"mode"` // debug/release
|
||||||
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
|
||||||
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("server.port", 8080)
|
viper.SetDefault("server.port", 8080)
|
||||||
viper.SetDefault("server.mode", "debug")
|
viper.SetDefault("server.mode", "debug")
|
||||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||||
|
|
||||||
// Database
|
// Database
|
||||||
viper.SetDefault("database.host", "localhost")
|
viper.SetDefault("database.host", "localhost")
|
||||||
@@ -210,10 +210,10 @@ func setDefaults() {
|
|||||||
|
|
||||||
// TokenRefresh
|
// TokenRefresh
|
||||||
viper.SetDefault("token_refresh.enabled", true)
|
viper.SetDefault("token_refresh.enabled", true)
|
||||||
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次
|
||||||
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 1.5) // 提前1.5小时刷新
|
||||||
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次
|
||||||
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) Validate() error {
|
func (c *Config) Validate() error {
|
||||||
|
|||||||
@@ -573,7 +573,7 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
|||||||
|
|
||||||
// For API Key accounts: return models based on model_mapping
|
// For API Key accounts: return models based on model_mapping
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
// No mapping configured, return default models
|
// No mapping configured, return default models
|
||||||
response.Success(c, claude.DefaultModels)
|
response.Success(c, claude.DefaultModels)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -236,7 +236,6 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
|||||||
response.Paginated(c, accounts, total, page, pageSize)
|
response.Paginated(c, accounts, total, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||||
type BatchCreateProxyItem struct {
|
type BatchCreateProxyItem struct {
|
||||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||||
|
|||||||
@@ -156,10 +156,10 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
|||||||
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
func (h *RedeemHandler) GetStats(c *gin.Context) {
|
||||||
// Return mock data for now
|
// Return mock data for now
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"total_codes": 0,
|
"total_codes": 0,
|
||||||
"active_codes": 0,
|
"active_codes": 0,
|
||||||
"used_codes": 0,
|
"used_codes": 0,
|
||||||
"expired_codes": 0,
|
"expired_codes": 0,
|
||||||
"total_value_distributed": 0.0,
|
"total_value_distributed": 0.0,
|
||||||
"by_type": gin.H{
|
"by_type": gin.H{
|
||||||
"balance": 0,
|
"balance": 0,
|
||||||
@@ -187,7 +187,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
writer := csv.NewWriter(&buf)
|
writer := csv.NewWriter(&buf)
|
||||||
|
|
||||||
// Write header
|
// Write header
|
||||||
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
|
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Write data rows
|
// Write data rows
|
||||||
for _, code := range codes {
|
for _, code := range codes {
|
||||||
@@ -199,7 +202,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
if code.UsedAt != nil {
|
if code.UsedAt != nil {
|
||||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
writer.Write([]string{
|
if err := writer.Write([]string{
|
||||||
fmt.Sprintf("%d", code.ID),
|
fmt.Sprintf("%d", code.ID),
|
||||||
code.Code,
|
code.Code,
|
||||||
code.Type,
|
code.Type,
|
||||||
@@ -208,10 +211,17 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
|||||||
usedBy,
|
usedBy,
|
||||||
usedAt,
|
usedAt,
|
||||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
})
|
}); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Flush()
|
writer.Flush()
|
||||||
|
if err := writer.Error(); err != nil {
|
||||||
|
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.Header("Content-Type", "text/csv")
|
c.Header("Content-Type", "text/csv")
|
||||||
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
|
||||||
|
|||||||
@@ -268,7 +268,9 @@ func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id
|
|||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
*streamStarted = true
|
*streamStarted = true
|
||||||
}
|
}
|
||||||
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
|
if _, err := fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,7 +416,9 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
if ok {
|
if ok {
|
||||||
// Send error event in SSE format
|
// Send error event in SSE format
|
||||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
||||||
fmt.Fprint(c.Writer, errorEvent)
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -574,11 +578,11 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
|||||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"id": "msg_mock_warmup",
|
"id": "msg_mock_warmup",
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": model,
|
"model": model,
|
||||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||||
"stop_reason": "end_turn",
|
"stop_reason": "end_turn",
|
||||||
"usage": gin.H{
|
"usage": gin.H{
|
||||||
"input_tokens": 10,
|
"input_tokens": 10,
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ type Account struct {
|
|||||||
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
Extra JSONB `gorm:"type:jsonb;default:'{}'" json:"extra"` // 扩展信息
|
||||||
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
ProxyID *int64 `gorm:"index" json:"proxy_id"`
|
||||||
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
Concurrency int `gorm:"default:3;not null" json:"concurrency"`
|
||||||
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
Priority int `gorm:"default:50;not null" json:"priority"` // 1-100,越小越高
|
||||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled/error
|
||||||
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
ErrorMessage string `gorm:"type:text" json:"error_message"`
|
||||||
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
LastUsedAt *time.Time `gorm:"index" json:"last_used_at"`
|
||||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||||
@@ -163,7 +163,7 @@ func (a *Account) GetModelMapping() map[string]string {
|
|||||||
// 如果没有设置模型映射,则支持所有模型
|
// 如果没有设置模型映射,则支持所有模型
|
||||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return true // 没有映射配置,支持所有模型
|
return true // 没有映射配置,支持所有模型
|
||||||
}
|
}
|
||||||
_, exists := mapping[requestedModel]
|
_, exists := mapping[requestedModel]
|
||||||
@@ -174,7 +174,7 @@ func (a *Account) IsModelSupported(requestedModel string) bool {
|
|||||||
// 如果没有映射,返回原始模型名
|
// 如果没有映射,返回原始模型名
|
||||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||||
mapping := a.GetModelMapping()
|
mapping := a.GetModelMapping()
|
||||||
if mapping == nil || len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return requestedModel
|
return requestedModel
|
||||||
}
|
}
|
||||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Group struct {
|
type Group struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
Name string `gorm:"uniqueIndex;size:100;not null" json:"name"`
|
||||||
Description string `gorm:"type:text" json:"description"`
|
Description string `gorm:"type:text" json:"description"`
|
||||||
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
Platform string `gorm:"size:50;default:anthropic;not null" json:"platform"` // anthropic/openai/gemini
|
||||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1.0;not null" json:"rate_multiplier"`
|
||||||
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
IsExclusive bool `gorm:"default:false;not null" json:"is_exclusive"`
|
||||||
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
Status string `gorm:"size:20;default:active;not null" json:"status"` // active/disabled
|
||||||
|
|
||||||
// 订阅功能字段
|
// 订阅功能字段
|
||||||
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
SubscriptionType string `gorm:"size:20;default:standard;not null" json:"subscription_type"` // standard/subscription
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import (
|
|||||||
type RedeemCode struct {
|
type RedeemCode struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
Code string `gorm:"uniqueIndex;size:32;not null" json:"code"`
|
||||||
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
Type string `gorm:"size:20;default:balance;not null" json:"type"` // balance/concurrency/subscription
|
||||||
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
Value float64 `gorm:"type:decimal(20,8);not null" json:"value"` // 面值(USD)或并发数或有效天数
|
||||||
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
Status string `gorm:"size:20;default:unused;not null" json:"status"` // unused/used
|
||||||
UsedBy *int64 `gorm:"index" json:"used_by"`
|
UsedBy *int64 `gorm:"index" json:"used_by"`
|
||||||
UsedAt *time.Time `json:"used_at"`
|
UsedAt *time.Time `json:"used_at"`
|
||||||
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
CreatedAt time.Time `gorm:"not null" json:"created_at"`
|
||||||
|
|
||||||
// 订阅类型专用字段
|
// 订阅类型专用字段
|
||||||
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
GroupID *int64 `gorm:"index" json:"group_id"` // 订阅分组ID (仅subscription类型使用)
|
||||||
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
ValidityDays int `gorm:"default:30" json:"validity_days"` // 订阅有效天数 (仅subscription类型使用)
|
||||||
|
|
||||||
// 关联
|
// 关联
|
||||||
@@ -40,8 +40,10 @@ func (r *RedeemCode) CanUse() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRedeemCode 生成唯一的兑换码
|
// GenerateRedeemCode 生成唯一的兑换码
|
||||||
func GenerateRedeemCode() string {
|
func GenerateRedeemCode() (string, error) {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
rand.Read(b)
|
if _, err := rand.Read(b); err != nil {
|
||||||
return hex.EncodeToString(b)
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,17 +19,17 @@ func (Setting) TableName() string {
|
|||||||
// 设置Key常量
|
// 设置Key常量
|
||||||
const (
|
const (
|
||||||
// 注册设置
|
// 注册设置
|
||||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||||
|
|
||||||
// 邮件服务设置
|
// 邮件服务设置
|
||||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||||
|
|
||||||
// Cloudflare Turnstile 设置
|
// Cloudflare Turnstile 设置
|
||||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ type UsageLog struct {
|
|||||||
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
OutputCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"output_cost"`
|
||||||
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
CacheCreationCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_creation_cost"`
|
||||||
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
CacheReadCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"cache_read_cost"`
|
||||||
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
TotalCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"total_cost"` // 原始总费用
|
||||||
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
ActualCost float64 `gorm:"type:decimal(20,10);default:0;not null" json:"actual_cost"` // 实际扣除费用
|
||||||
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
RateMultiplier float64 `gorm:"type:decimal(10,4);default:1;not null" json:"rate_multiplier"` // 计费倍率
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID int64 `gorm:"primaryKey" json:"id"`
|
ID int64 `gorm:"primaryKey" json:"id"`
|
||||||
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
Email string `gorm:"uniqueIndex;size:255;not null" json:"email"`
|
||||||
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
PasswordHash string `gorm:"size:255;not null" json:"-"`
|
||||||
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
Role string `gorm:"size:20;default:user;not null" json:"role"` // admin/user
|
||||||
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
Balance float64 `gorm:"type:decimal(20,8);default:0;not null" json:"balance"`
|
||||||
|
|||||||
@@ -37,11 +37,15 @@ func TestInitInvalidTimezone(t *testing.T) {
|
|||||||
|
|
||||||
func TestTimeNowAffected(t *testing.T) {
|
func TestTimeNowAffected(t *testing.T) {
|
||||||
// Reset to UTC first
|
// Reset to UTC first
|
||||||
Init("UTC")
|
if err := Init("UTC"); err != nil {
|
||||||
|
t.Fatalf("Init failed with UTC: %v", err)
|
||||||
|
}
|
||||||
utcNow := time.Now()
|
utcNow := time.Now()
|
||||||
|
|
||||||
// Switch to Shanghai (UTC+8)
|
// Switch to Shanghai (UTC+8)
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
shanghaiNow := time.Now()
|
shanghaiNow := time.Now()
|
||||||
|
|
||||||
// The times should be the same instant, but different timezone representation
|
// The times should be the same instant, but different timezone representation
|
||||||
@@ -58,7 +62,9 @@ func TestTimeNowAffected(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestToday(t *testing.T) {
|
func TestToday(t *testing.T) {
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
today := Today()
|
today := Today()
|
||||||
now := Now()
|
now := Now()
|
||||||
@@ -75,7 +81,9 @@ func TestToday(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStartOfDay(t *testing.T) {
|
func TestStartOfDay(t *testing.T) {
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create a time at 15:30:45
|
// Create a time at 15:30:45
|
||||||
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
testTime := time.Date(2024, 6, 15, 15, 30, 45, 123456789, Location())
|
||||||
@@ -91,7 +99,9 @@ func TestTruncateVsStartOfDay(t *testing.T) {
|
|||||||
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
// This test demonstrates why Truncate(24*time.Hour) can be problematic
|
||||||
// and why StartOfDay is more reliable for timezone-aware code
|
// and why StartOfDay is more reliable for timezone-aware code
|
||||||
|
|
||||||
Init("Asia/Shanghai")
|
if err := Init("Asia/Shanghai"); err != nil {
|
||||||
|
t.Fatalf("Init failed with Asia/Shanghai: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
now := Now()
|
now := Now()
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||||
@@ -63,7 +63,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||||
@@ -78,7 +78,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
defer func() { _ = out.Close() }()
|
||||||
|
|
||||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||||
@@ -89,7 +89,7 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
|||||||
|
|
||||||
// Check if we hit the limit (downloaded more than maxSize)
|
// Check if we hit the limit (downloaded more than maxSize)
|
||||||
if written > maxSize {
|
if written > maxSize {
|
||||||
os.Remove(dest) // Clean up partial file
|
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +106,7 @@ func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
@@ -52,7 +52,7 @@ func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (st
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
latencyMs := time.Since(startTime).Milliseconds()
|
latencyMs := time.Since(startTime).Milliseconds()
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("send request: %w", err)
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
var result service.TurnstileVerifyResponse
|
var result service.TurnstileVerifyResponse
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
|||||||
@@ -51,16 +51,23 @@ func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OA
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateSessionString generates a Claude Code style session string
|
// generateSessionString generates a Claude Code style session string
|
||||||
func generateSessionString() string {
|
func generateSessionString() (string, error) {
|
||||||
bytes := make([]byte, 32)
|
bytes := make([]byte, 32)
|
||||||
rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
hex64 := hex.EncodeToString(bytes)
|
hex64 := hex.EncodeToString(bytes)
|
||||||
sessionUUID := uuid.New().String()
|
sessionUUID := uuid.New().String()
|
||||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
|
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTestPayload creates a Claude Code style test request payload
|
// createTestPayload creates a Claude Code style test request payload
|
||||||
func createTestPayload(modelID string) map[string]interface{} {
|
func createTestPayload(modelID string) (map[string]interface{}, error) {
|
||||||
|
sessionID, err := generateSessionString()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{
|
||||||
"model": modelID,
|
"model": modelID,
|
||||||
"messages": []map[string]interface{}{
|
"messages": []map[string]interface{}{
|
||||||
@@ -87,12 +94,12 @@ func createTestPayload(modelID string) map[string]interface{} {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"metadata": map[string]string{
|
"metadata": map[string]string{
|
||||||
"user_id": generateSessionString(),
|
"user_id": sessionID,
|
||||||
},
|
},
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"temperature": 1,
|
"temperature": 1,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAccountConnection tests an account's connection by sending a test request
|
// TestAccountConnection tests an account's connection by sending a test request
|
||||||
@@ -116,7 +123,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
// For API Key accounts with model mapping, map the model
|
// For API Key accounts with model mapping, map the model
|
||||||
if account.Type == "apikey" {
|
if account.Type == "apikey" {
|
||||||
mapping := account.GetModelMapping()
|
mapping := account.GetModelMapping()
|
||||||
if mapping != nil && len(mapping) > 0 {
|
if len(mapping) > 0 {
|
||||||
if mappedModel, exists := mapping[testModelID]; exists {
|
if mappedModel, exists := mapping[testModelID]; exists {
|
||||||
testModelID = mappedModel
|
testModelID = mappedModel
|
||||||
}
|
}
|
||||||
@@ -178,7 +185,10 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
|
|
||||||
// Create Claude Code style payload (same for all account types)
|
// Create Claude Code style payload (same for all account types)
|
||||||
payload := createTestPayload(testModelID)
|
payload, err := createTestPayload(testModelID)
|
||||||
|
if err != nil {
|
||||||
|
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||||
|
}
|
||||||
payloadBytes, _ := json.Marshal(payload)
|
payloadBytes, _ := json.Marshal(payload)
|
||||||
|
|
||||||
// Send test_start event
|
// Send test_start event
|
||||||
@@ -216,7 +226,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
@@ -284,7 +294,10 @@ func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error
|
|||||||
// sendEvent sends a SSE event to the client
|
// sendEvent sends a SSE event to the client
|
||||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||||
eventJSON, _ := json.Marshal(event)
|
eventJSON, _ := json.Marshal(event)
|
||||||
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
|
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||||
|
log.Printf("failed to write SSE event: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.Writer.Flush()
|
c.Writer.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
@@ -309,7 +310,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
|
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, id); err != nil {
|
||||||
|
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", id, err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -317,8 +320,13 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
// Create adjustment records for balance/concurrency changes
|
// Create adjustment records for balance/concurrency changes
|
||||||
balanceDiff := user.Balance - oldBalance
|
balanceDiff := user.Balance - oldBalance
|
||||||
if balanceDiff != 0 {
|
if balanceDiff != 0 {
|
||||||
|
code, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
adjustmentRecord := &model.RedeemCode{
|
adjustmentRecord := &model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: code,
|
||||||
Type: model.AdjustmentTypeAdminBalance,
|
Type: model.AdjustmentTypeAdminBalance,
|
||||||
Value: balanceDiff,
|
Value: balanceDiff,
|
||||||
Status: model.StatusUsed,
|
Status: model.StatusUsed,
|
||||||
@@ -327,15 +335,19 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
adjustmentRecord.UsedAt = &now
|
adjustmentRecord.UsedAt = &now
|
||||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||||
// Log error but don't fail the update
|
log.Printf("failed to create balance adjustment redeem code: %v", err)
|
||||||
// The user update has already succeeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||||
if concurrencyDiff != 0 {
|
if concurrencyDiff != 0 {
|
||||||
|
code, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
adjustmentRecord := &model.RedeemCode{
|
adjustmentRecord := &model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: code,
|
||||||
Type: model.AdjustmentTypeAdminConcurrency,
|
Type: model.AdjustmentTypeAdminConcurrency,
|
||||||
Value: float64(concurrencyDiff),
|
Value: float64(concurrencyDiff),
|
||||||
Status: model.StatusUsed,
|
Status: model.StatusUsed,
|
||||||
@@ -344,8 +356,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
adjustmentRecord.UsedAt = &now
|
adjustmentRecord.UsedAt = &now
|
||||||
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
|
||||||
// Log error but don't fail the update
|
log.Printf("failed to create concurrency adjustment redeem code: %v", err)
|
||||||
// The user update has already succeeded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,7 +399,9 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||||
|
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -579,7 +592,9 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
|||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
for _, userID := range affectedUserIDs {
|
for _, userID := range affectedUserIDs {
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil {
|
||||||
|
log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -646,10 +661,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
|||||||
if input.Type != "" {
|
if input.Type != "" {
|
||||||
account.Type = input.Type
|
account.Type = input.Type
|
||||||
}
|
}
|
||||||
if input.Credentials != nil && len(input.Credentials) > 0 {
|
if len(input.Credentials) > 0 {
|
||||||
account.Credentials = model.JSONB(input.Credentials)
|
account.Credentials = model.JSONB(input.Credentials)
|
||||||
}
|
}
|
||||||
if input.Extra != nil && len(input.Extra) > 0 {
|
if len(input.Extra) > 0 {
|
||||||
account.Extra = model.JSONB(input.Extra)
|
account.Extra = model.JSONB(input.Extra)
|
||||||
}
|
}
|
||||||
if input.ProxyID != nil {
|
if input.ProxyID != nil {
|
||||||
@@ -831,8 +846,12 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
|||||||
|
|
||||||
codes := make([]model.RedeemCode, 0, input.Count)
|
codes := make([]model.RedeemCode, 0, input.Count)
|
||||||
for i := 0; i < input.Count; i++ {
|
for i := 0; i < input.Count; i++ {
|
||||||
|
codeValue, err := model.GenerateRedeemCode()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
code := model.RedeemCode{
|
code := model.RedeemCode{
|
||||||
Code: model.GenerateRedeemCode(),
|
Code: codeValue,
|
||||||
Type: input.Type,
|
Type: input.Type,
|
||||||
Value: input.Value,
|
Value: input.Value,
|
||||||
Status: model.StatusUnused,
|
Status: model.StatusUnused,
|
||||||
|
|||||||
@@ -100,10 +100,13 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
|||||||
|
|
||||||
// 检查字符:只允许字母、数字、下划线、连字符
|
// 检查字符:只允许字母、数字、下划线、连字符
|
||||||
for _, c := range key {
|
for _, c := range key {
|
||||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
|
if (c >= 'a' && c <= 'z') ||
|
||||||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
|
(c >= 'A' && c <= 'Z') ||
|
||||||
return ErrApiKeyInvalidChars
|
(c >= '0' && c <= '9') ||
|
||||||
|
c == '_' || c == '-' {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
return ErrApiKeyInvalidChars
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -9,12 +9,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// Wait polling interval
|
|
||||||
waitPollInterval = 100 * time.Millisecond
|
|
||||||
|
|
||||||
// Default max wait time
|
|
||||||
defaultMaxWait = 60 * time.Second
|
|
||||||
|
|
||||||
// Default extra wait slots beyond concurrency limit
|
// Default extra wait slots beyond concurrency limit
|
||||||
defaultExtraWaitSlots = 20
|
defaultExtraWaitSlots = 20
|
||||||
)
|
)
|
||||||
@@ -31,7 +25,7 @@ func NewConcurrencyService(cache ports.ConcurrencyCache) *ConcurrencyService {
|
|||||||
|
|
||||||
// AcquireResult represents the result of acquiring a concurrency slot
|
// AcquireResult represents the result of acquiring a concurrency slot
|
||||||
type AcquireResult struct {
|
type AcquireResult struct {
|
||||||
Acquired bool
|
Acquired bool
|
||||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +48,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i
|
|||||||
|
|
||||||
if acquired {
|
if acquired {
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
ReleaseFunc: func() {
|
ReleaseFunc: func() {
|
||||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -90,7 +84,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64,
|
|||||||
|
|
||||||
if acquired {
|
if acquired {
|
||||||
return &AcquireResult{
|
return &AcquireResult{
|
||||||
Acquired: true,
|
Acquired: true,
|
||||||
ReleaseFunc: func() {
|
ReleaseFunc: func() {
|
||||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -133,13 +133,13 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls dial: %w", err)
|
return fmt.Errorf("tls dial: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
client, err := smtp.NewClient(conn, host)
|
client, err := smtp.NewClient(conn, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("new smtp client: %w", err)
|
return fmt.Errorf("new smtp client: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
if err = client.Auth(auth); err != nil {
|
if err = client.Auth(auth); err != nil {
|
||||||
return fmt.Errorf("smtp auth: %w", err)
|
return fmt.Errorf("smtp auth: %w", err)
|
||||||
@@ -303,13 +303,13 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("tls connection failed: %w", err)
|
return fmt.Errorf("tls connection failed: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
client, err := smtp.NewClient(conn, config.Host)
|
client, err := smtp.NewClient(conn, config.Host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||||
if err = client.Auth(auth); err != nil {
|
if err = client.Auth(auth); err != nil {
|
||||||
@@ -324,7 +324,7 @@ func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("smtp connection failed: %w", err)
|
return fmt.Errorf("smtp connection failed: %w", err)
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
|
|
||||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||||
if err = client.Auth(auth); err != nil {
|
if err = client.Auth(auth); err != nil {
|
||||||
|
|||||||
@@ -281,7 +281,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
// 同时检查模型支持
|
// 同时检查模型支持
|
||||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||||
// 续期粘性会话
|
// 续期粘性会话
|
||||||
s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||||
|
}
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -331,7 +333,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
|||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" {
|
if sessionHash != "" {
|
||||||
s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL)
|
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||||
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return selected, nil
|
return selected, nil
|
||||||
@@ -411,7 +415,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
// 处理错误响应(包括401,由后台TokenRefreshService维护token有效性)
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
@@ -678,7 +682,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转发行
|
// 转发行
|
||||||
fmt.Fprintf(w, "%s\n", line)
|
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||||
|
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||||
|
}
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
// 解析usage数据
|
// 解析usage数据
|
||||||
@@ -985,7 +991,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||||
return fmt.Errorf("upstream request failed: %w", err)
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
// 预编译正则表达式(避免每次调用重新编译)
|
// 预编译正则表达式(避免每次调用重新编译)
|
||||||
var (
|
var (
|
||||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||||
|
|||||||
@@ -254,7 +254,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,7 +285,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
@@ -78,7 +79,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,7 +147,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
|||||||
}
|
}
|
||||||
newNotes += input.Notes
|
newNotes += input.Notes
|
||||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||||
// 备注更新失败不影响主流程
|
log.Printf("update subscription notes failed: sub_id=%d err=%v", existingSub.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +157,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,7 +178,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,7 +279,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,7 +312,7 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
|||||||
go func() {
|
go func() {
|
||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ var (
|
|||||||
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
|
||||||
)
|
)
|
||||||
|
|
||||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
|
||||||
|
|
||||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||||
type TurnstileVerifier interface {
|
type TurnstileVerifier interface {
|
||||||
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -190,7 +191,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create temp dir: %w", err)
|
return fmt.Errorf("failed to create temp dir: %w", err)
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tempDir)
|
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||||
|
|
||||||
// Download archive
|
// Download archive
|
||||||
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
|
||||||
@@ -223,7 +224,7 @@ func (s *UpdateService) PerformUpdate(ctx context.Context) error {
|
|||||||
backupPath := exePath + ".backup"
|
backupPath := exePath + ".backup"
|
||||||
|
|
||||||
// Remove old backup if exists
|
// Remove old backup if exists
|
||||||
os.Remove(backupPath)
|
_ = os.Remove(backupPath)
|
||||||
|
|
||||||
// Step 1: Move current binary to backup
|
// Step 1: Move current binary to backup
|
||||||
if err := os.Rename(exePath, backupPath); err != nil {
|
if err := os.Rename(exePath, backupPath); err != nil {
|
||||||
@@ -349,7 +350,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
if _, err := io.Copy(h, f); err != nil {
|
if _, err := io.Copy(h, f); err != nil {
|
||||||
@@ -379,7 +380,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
var reader io.Reader = f
|
var reader io.Reader = f
|
||||||
|
|
||||||
@@ -389,7 +390,7 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer gzr.Close()
|
defer func() { _ = gzr.Close() }()
|
||||||
reader = gzr
|
reader = gzr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -435,10 +436,12 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
// Use LimitReader to prevent decompression bombs
|
// Use LimitReader to prevent decompression bombs
|
||||||
limited := io.LimitReader(tr, maxBinarySize)
|
limited := io.LimitReader(tr, maxBinarySize)
|
||||||
if _, err := io.Copy(out, limited); err != nil {
|
if _, err := io.Copy(out, limited); err != nil {
|
||||||
out.Close()
|
_ = out.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := out.Close(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
out.Close()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -451,11 +454,13 @@ func (s *UpdateService) extractBinary(archivePath, destPath string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer out.Close()
|
|
||||||
|
|
||||||
limited := io.LimitReader(reader, maxBinarySize)
|
limited := io.LimitReader(reader, maxBinarySize)
|
||||||
_, err = io.Copy(out, limited)
|
if _, err := io.Copy(out, limited); err != nil {
|
||||||
return err
|
_ = out.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return out.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
|
||||||
@@ -499,7 +504,7 @@ func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data, _ := json.Marshal(cacheData)
|
data, _ := json.Marshal(cacheData)
|
||||||
s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
_ = s.cache.SetUpdateInfo(ctx, string(data), time.Duration(updateCacheTTL)*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compareVersions compares two semantic versions
|
// compareVersions compares two semantic versions
|
||||||
@@ -523,7 +528,9 @@ func parseVersion(v string) [3]int {
|
|||||||
parts := strings.Split(v, ".")
|
parts := strings.Split(v, ".")
|
||||||
result := [3]int{0, 0, 0}
|
result := [3]int{0, 0, 0}
|
||||||
for i := 0; i < len(parts) && i < 3; i++ {
|
for i := 0; i < len(parts) && i < 3; i++ {
|
||||||
fmt.Sscanf(parts[i], "%d", &result[i])
|
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||||
|
result[i] = parsed
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -352,4 +352,3 @@ func install(c *gin.Context) {
|
|||||||
"restart": true,
|
"restart": true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ import (
|
|||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config paths
|
// Config paths
|
||||||
@@ -101,7 +101,14 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get db instance: %w", err)
|
return fmt.Errorf("failed to get db instance: %w", err)
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer func() {
|
||||||
|
if sqlDB == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -129,7 +136,10 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now connect to the target database to verify
|
// Now connect to the target database to verify
|
||||||
sqlDB.Close()
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
sqlDB = nil
|
||||||
|
|
||||||
targetDSN := fmt.Sprintf(
|
targetDSN := fmt.Sprintf(
|
||||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||||
@@ -145,7 +155,11 @@ func TestDatabaseConnection(cfg *DatabaseConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get target db instance: %w", err)
|
return fmt.Errorf("failed to get target db instance: %w", err)
|
||||||
}
|
}
|
||||||
defer targetSqlDB.Close()
|
defer func() {
|
||||||
|
if err := targetSqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel2()
|
defer cancel2()
|
||||||
@@ -164,7 +178,11 @@ func TestRedisConnection(cfg *RedisConfig) error {
|
|||||||
Password: cfg.Password,
|
Password: cfg.Password,
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
})
|
})
|
||||||
defer rdb.Close()
|
defer func() {
|
||||||
|
if err := rdb.Close(); err != nil {
|
||||||
|
log.Printf("failed to close redis client: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -185,7 +203,11 @@ func Install(cfg *SetupConfig) error {
|
|||||||
|
|
||||||
// Generate JWT secret if not provided
|
// Generate JWT secret if not provided
|
||||||
if cfg.JWT.Secret == "" {
|
if cfg.JWT.Secret == "" {
|
||||||
cfg.JWT.Secret = generateSecret(32)
|
secret, err := generateSecret(32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||||
|
}
|
||||||
|
cfg.JWT.Secret = secret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test connections
|
// Test connections
|
||||||
@@ -243,7 +265,11 @@ func initializeDatabase(cfg *SetupConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer func() {
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 使用 model 包的 AutoMigrate,确保模型定义统一
|
// 使用 model 包的 AutoMigrate,确保模型定义统一
|
||||||
return model.AutoMigrate(db)
|
return model.AutoMigrate(db)
|
||||||
@@ -265,7 +291,11 @@ func createAdminUser(cfg *SetupConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer sqlDB.Close()
|
defer func() {
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
log.Printf("failed to close postgres connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Check if admin already exists
|
// Check if admin already exists
|
||||||
var count int64
|
var count int64
|
||||||
@@ -352,10 +382,12 @@ func writeConfigFile(cfg *SetupConfig) error {
|
|||||||
return os.WriteFile(ConfigFile, data, 0600)
|
return os.WriteFile(ConfigFile, data, 0600)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSecret(length int) string {
|
func generateSecret(length int) (string, error) {
|
||||||
bytes := make([]byte, length)
|
bytes := make([]byte, length)
|
||||||
rand.Read(bytes)
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
return hex.EncodeToString(bytes)
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
@@ -431,13 +463,21 @@ func AutoSetupFromEnv() error {
|
|||||||
|
|
||||||
// Generate JWT secret if not provided
|
// Generate JWT secret if not provided
|
||||||
if cfg.JWT.Secret == "" {
|
if cfg.JWT.Secret == "" {
|
||||||
cfg.JWT.Secret = generateSecret(32)
|
secret, err := generateSecret(32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate jwt secret: %w", err)
|
||||||
|
}
|
||||||
|
cfg.JWT.Secret = secret
|
||||||
log.Println("Generated JWT secret automatically")
|
log.Println("Generated JWT secret automatically")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate admin password if not provided
|
// Generate admin password if not provided
|
||||||
if cfg.Admin.Password == "" {
|
if cfg.Admin.Password == "" {
|
||||||
cfg.Admin.Password = generateSecret(16)
|
password, err := generateSecret(16)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate admin password: %w", err)
|
||||||
|
}
|
||||||
|
cfg.Admin.Password = password
|
||||||
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
log.Printf("Generated admin password: %s", cfg.Admin.Password)
|
||||||
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
log.Println("IMPORTANT: Save this password! It will not be shown again.")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func ServeEmbeddedFrontend() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if file, err := distFS.Open(cleanPath); err == nil {
|
if file, err := distFS.Open(cleanPath); err == nil {
|
||||||
file.Close()
|
_ = file.Close()
|
||||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
@@ -59,7 +59,7 @@ func serveIndexHTML(c *gin.Context, fsys fs.FS) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer func() { _ = file.Close() }()
|
||||||
|
|
||||||
content, err := io.ReadAll(file)
|
content, err := io.ReadAll(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user