fix(audit): 第二批审计修复 — P0 生产 Bug、安全加固、性能优化、缓存一致性、代码质量
基于 backend-code-audit 审计报告,修复剩余 P0/P1/P2 共 34 项问题: P0 生产 Bug: - 修复 time.Since(time.Now()) 计时逻辑错误 (P0-03) - generateRandomID 改用 crypto/rand 替代固定索引 (P0-04) - IncrementQuotaUsed 重写为 Ent 原子操作消除 TOCTOU 竞态 (P0-05) 安全加固: - gateway/openai handler 错误响应替换为泛化消息,防止内部信息泄露 (P1-14) - usage_log_repo dateFormat 参数改用白名单映射,防止 SQL 注入 (P1-16) - 默认配置安全加固:sslmode=prefer、response_headers=true、mode=release (P1-18/19, P2-15) 性能优化: - gateway handler 循环内 defer 替换为显式 releaseWait 闭包 (P1-02) - group_repo/promo_code_repo Count 前 Clone 查询避免状态污染 (P1-03) - usage_log_repo 四个查询添加 LIMIT 10000 防止 OOM (P1-07) - GetBatchUsageStats 添加时间范围参数,默认最近 30 天 (P1-10) - ip.go CIDR 预编译为包级变量 (P1-11) - BatchUpdateCredentials 重构为先验证后更新 (P1-13) 缓存一致性: - billing_cache 添加 jitteredTTL 防止缓存雪崩 (P2-10) - DeductUserBalance/UpdateSubscriptionUsage 错误传播修复 (P2-12) - UserService.UpdateBalance 成功后异步失效 billingCache (P2-13) 代码质量: - search 截断改为按 rune 处理,支持多字节字符 (P2-01) - TLS Handshake 改为 HandshakeContext 支持 context 取消 (P2-07) - CORS 预检添加 Access-Control-Max-Age: 86400 (P2-16) 测试覆盖: - 新增 user_service_test.go(UpdateBalance 缓存失效 6 个用例) - 新增 batch_update_credentials_test.go(fail-fast + 类型验证 7 个用例) - 新增 response_transformer_test.go、ip_test.go、usage_log_repo_unit_test.go、search_truncate_test.go - 集成测试:IncrementQuotaUsed 并发测试、billing_cache 错误传播测试 - config_test.go 补充 server.mode/sslmode 默认值断言 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -65,7 +65,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
|
||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
|
||||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
|
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, configConfig)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
|
|||||||
@@ -715,7 +715,7 @@ func setDefaults() {
|
|||||||
// Server
|
// Server
|
||||||
viper.SetDefault("server.host", "0.0.0.0")
|
viper.SetDefault("server.host", "0.0.0.0")
|
||||||
viper.SetDefault("server.port", 8080)
|
viper.SetDefault("server.port", 8080)
|
||||||
viper.SetDefault("server.mode", "debug")
|
viper.SetDefault("server.mode", "release")
|
||||||
viper.SetDefault("server.frontend_url", "")
|
viper.SetDefault("server.frontend_url", "")
|
||||||
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头
|
||||||
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时
|
||||||
@@ -751,7 +751,7 @@ func setDefaults() {
|
|||||||
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
|
||||||
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
|
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
|
||||||
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
|
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
|
||||||
viper.SetDefault("security.response_headers.enabled", false)
|
viper.SetDefault("security.response_headers.enabled", true)
|
||||||
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
viper.SetDefault("security.response_headers.additional_allowed", []string{})
|
||||||
viper.SetDefault("security.response_headers.force_remove", []string{})
|
viper.SetDefault("security.response_headers.force_remove", []string{})
|
||||||
viper.SetDefault("security.csp.enabled", true)
|
viper.SetDefault("security.csp.enabled", true)
|
||||||
@@ -789,9 +789,9 @@ func setDefaults() {
|
|||||||
viper.SetDefault("database.user", "postgres")
|
viper.SetDefault("database.user", "postgres")
|
||||||
viper.SetDefault("database.password", "postgres")
|
viper.SetDefault("database.password", "postgres")
|
||||||
viper.SetDefault("database.dbname", "sub2api")
|
viper.SetDefault("database.dbname", "sub2api")
|
||||||
viper.SetDefault("database.sslmode", "disable")
|
viper.SetDefault("database.sslmode", "prefer")
|
||||||
viper.SetDefault("database.max_open_conns", 50)
|
viper.SetDefault("database.max_open_conns", 256)
|
||||||
viper.SetDefault("database.max_idle_conns", 10)
|
viper.SetDefault("database.max_idle_conns", 128)
|
||||||
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
|
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
|
||||||
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
|
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
|
||||||
|
|
||||||
|
|||||||
@@ -87,8 +87,34 @@ func TestLoadDefaultSecurityToggles(t *testing.T) {
|
|||||||
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
|
if !cfg.Security.URLAllowlist.AllowPrivateHosts {
|
||||||
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
|
t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true")
|
||||||
}
|
}
|
||||||
if cfg.Security.ResponseHeaders.Enabled {
|
if !cfg.Security.ResponseHeaders.Enabled {
|
||||||
t.Fatalf("ResponseHeaders.Enabled = true, want false")
|
t.Fatalf("ResponseHeaders.Enabled = false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultServerMode(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Server.Mode != "release" {
|
||||||
|
t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadDefaultDatabaseSSLMode(t *testing.T) {
|
||||||
|
viper.Reset()
|
||||||
|
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Database.SSLMode != "prefer" {
|
||||||
|
t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package admin
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -738,57 +739,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
success := 0
|
|
||||||
failed := 0
|
|
||||||
results := []gin.H{}
|
|
||||||
|
|
||||||
|
// 阶段一:预验证所有账号存在,收集 credentials
|
||||||
|
type accountUpdate struct {
|
||||||
|
ID int64
|
||||||
|
Credentials map[string]any
|
||||||
|
}
|
||||||
|
updates := make([]accountUpdate, 0, len(req.AccountIDs))
|
||||||
for _, accountID := range req.AccountIDs {
|
for _, accountID := range req.AccountIDs {
|
||||||
// Get account
|
|
||||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failed++
|
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
|
||||||
results = append(results, gin.H{
|
return
|
||||||
"account_id": accountID,
|
|
||||||
"success": false,
|
|
||||||
"error": "Account not found",
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update credentials field
|
|
||||||
if account.Credentials == nil {
|
if account.Credentials == nil {
|
||||||
account.Credentials = make(map[string]any)
|
account.Credentials = make(map[string]any)
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Credentials[req.Field] = req.Value
|
account.Credentials[req.Field] = req.Value
|
||||||
|
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
|
||||||
|
}
|
||||||
|
|
||||||
// Update account
|
// 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
|
||||||
|
for _, u := range updates {
|
||||||
updateInput := &service.UpdateAccountInput{
|
updateInput := &service.UpdateAccountInput{
|
||||||
Credentials: account.Credentials,
|
Credentials: u.Credentials,
|
||||||
}
|
}
|
||||||
|
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
|
||||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
|
||||||
if err != nil {
|
return
|
||||||
failed++
|
|
||||||
results = append(results, gin.H{
|
|
||||||
"account_id": accountID,
|
|
||||||
"success": false,
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
success++
|
|
||||||
results = append(results, gin.H{
|
|
||||||
"account_id": accountID,
|
|
||||||
"success": true,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, gin.H{
|
response.Success(c, gin.H{
|
||||||
"success": success,
|
"success": len(updates),
|
||||||
"failed": failed,
|
"failed": 0,
|
||||||
"results": results,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
200
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
200
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
|
||||||
|
type failingAdminService struct {
|
||||||
|
*stubAdminService
|
||||||
|
failOnAccountID int64
|
||||||
|
updateCallCount atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||||
|
f.updateCallCount.Add(1)
|
||||||
|
if id == f.failOnAccountID {
|
||||||
|
return nil, errors.New("database error")
|
||||||
|
}
|
||||||
|
return f.stubAdminService.UpdateAccount(ctx, id, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||||
|
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
|
||||||
|
return router, handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
|
||||||
|
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
Field: "account_uuid",
|
||||||
|
Value: "test-uuid",
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
|
||||||
|
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
|
||||||
|
// 让第 2 个账号(ID=2)更新时失败
|
||||||
|
svc := &failingAdminService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
failOnAccountID: 2,
|
||||||
|
}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
Field: "org_uuid",
|
||||||
|
Value: "test-org",
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
|
||||||
|
// 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用
|
||||||
|
require.Equal(t, int64(2), svc.updateCallCount.Load(),
|
||||||
|
"fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
|
||||||
|
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
|
||||||
|
svc := &getAccountFailingService{
|
||||||
|
stubAdminService: newStubAdminService(),
|
||||||
|
failOnAccountID: 1,
|
||||||
|
}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||||
|
AccountIDs: []int64{1, 2, 3},
|
||||||
|
Field: "account_uuid",
|
||||||
|
Value: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
|
||||||
|
type getAccountFailingService struct {
|
||||||
|
*stubAdminService
|
||||||
|
failOnAccountID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||||
|
if id == f.failOnAccountID {
|
||||||
|
return nil, errors.New("not found")
|
||||||
|
}
|
||||||
|
return f.stubAdminService.GetAccount(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
|
||||||
|
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"account_ids": []int64{1},
|
||||||
|
"field": "intercept_warmup_requests",
|
||||||
|
"value": "not-a-bool",
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||||
|
"intercept_warmup_requests 传入非 bool 值应返回 400")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
|
||||||
|
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"account_ids": []int64{1},
|
||||||
|
"field": "intercept_warmup_requests",
|
||||||
|
"value": true,
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code,
|
||||||
|
"intercept_warmup_requests 传入合法 bool 值应返回 200")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
|
||||||
|
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
// account_uuid 传入非 string 类型(number),应返回 400
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"account_ids": []int64{1},
|
||||||
|
"field": "account_uuid",
|
||||||
|
"value": 12345,
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||||
|
"account_uuid 传入非 string 值应返回 400")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
|
||||||
|
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||||
|
router, _ := setupAccountHandlerWithService(svc)
|
||||||
|
|
||||||
|
// account_uuid 传入 null(设置为空),应正常通过
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"account_ids": []int64{1},
|
||||||
|
"field": "account_uuid",
|
||||||
|
"value": nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code,
|
||||||
|
"account_uuid 传入 null 应返回 200")
|
||||||
|
}
|
||||||
@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get user usage stats")
|
response.Error(c, 500, "Failed to get user usage stats")
|
||||||
return
|
return
|
||||||
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.Error(c, 500, "Failed to get API key usage stats")
|
response.Error(c, 500, "Failed to get API key usage stats")
|
||||||
return
|
return
|
||||||
|
|||||||
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package admin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
|
||||||
|
func truncateSearchByRune(search string, maxRunes int) string {
|
||||||
|
if runes := []rune(search); len(runes) > maxRunes {
|
||||||
|
return string(runes[:maxRunes])
|
||||||
|
}
|
||||||
|
return search
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateSearchByRune(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxRunes int
|
||||||
|
wantLen int // 期望的 rune 长度
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "纯中文超长",
|
||||||
|
input: string(make([]rune, 150)),
|
||||||
|
maxRunes: 100,
|
||||||
|
wantLen: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "纯 ASCII 超长",
|
||||||
|
input: string(make([]byte, 150)),
|
||||||
|
maxRunes: 100,
|
||||||
|
wantLen: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空字符串",
|
||||||
|
input: "",
|
||||||
|
maxRunes: 100,
|
||||||
|
wantLen: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "恰好 100 个字符",
|
||||||
|
input: string(make([]rune, 100)),
|
||||||
|
maxRunes: 100,
|
||||||
|
wantLen: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "不足 100 字符不截断",
|
||||||
|
input: "hello世界",
|
||||||
|
maxRunes: 100,
|
||||||
|
wantLen: 7,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := truncateSearchByRune(tc.input, tc.maxRunes)
|
||||||
|
require.Equal(t, tc.wantLen, len([]rune(result)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
|
||||||
|
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
|
||||||
|
input := ""
|
||||||
|
for i := 0; i < 101; i++ {
|
||||||
|
input += "中"
|
||||||
|
}
|
||||||
|
result := truncateSearchByRune(input, 100)
|
||||||
|
|
||||||
|
require.Equal(t, 100, len([]rune(result)))
|
||||||
|
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
|
||||||
|
require.Equal(t, 300, len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
|
||||||
|
// 50 个 ASCII + 51 个中文 = 101 个 rune
|
||||||
|
input := ""
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
input += "a"
|
||||||
|
}
|
||||||
|
for i := 0; i < 51; i++ {
|
||||||
|
input += "中"
|
||||||
|
}
|
||||||
|
result := truncateSearchByRune(input, 100)
|
||||||
|
|
||||||
|
runes := []rune(result)
|
||||||
|
require.Equal(t, 100, len(runes))
|
||||||
|
// 前 50 个应该是 'a',后 50 个应该是 '中'
|
||||||
|
require.Equal(t, 'a', runes[0])
|
||||||
|
require.Equal(t, 'a', runes[49])
|
||||||
|
require.Equal(t, '中', runes[50])
|
||||||
|
require.Equal(t, '中', runes[99])
|
||||||
|
}
|
||||||
@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
|
|||||||
search := c.Query("search")
|
search := c.Query("search")
|
||||||
// 标准化和验证 search 参数
|
// 标准化和验证 search 参数
|
||||||
search = strings.TrimSpace(search)
|
search = strings.TrimSpace(search)
|
||||||
if len(search) > 100 {
|
if runes := []rune(search); len(runes) > 100 {
|
||||||
search = search[:100]
|
search = string(runes[:100])
|
||||||
}
|
}
|
||||||
|
|
||||||
filters := service.UserListFilters{
|
filters := service.UserListFilters{
|
||||||
|
|||||||
@@ -210,7 +210,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
if lastFailoverErr != nil {
|
||||||
@@ -258,12 +259,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err == nil && canWait {
|
if err == nil && canWait {
|
||||||
accountWaitCounted = true
|
accountWaitCounted = true
|
||||||
}
|
}
|
||||||
// Ensure the wait counter is decremented if we exit before acquiring the slot.
|
releaseWait := func() {
|
||||||
defer func() {
|
|
||||||
if accountWaitCounted {
|
if accountWaitCounted {
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -275,14 +276,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
releaseWait()
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Slot acquired: no longer waiting in queue.
|
// Slot acquired: no longer waiting in queue.
|
||||||
if accountWaitCounted {
|
releaseWait()
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -367,7 +366,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
if lastFailoverErr != nil {
|
||||||
@@ -415,11 +415,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
if err == nil && canWait {
|
if err == nil && canWait {
|
||||||
accountWaitCounted = true
|
accountWaitCounted = true
|
||||||
}
|
}
|
||||||
defer func() {
|
releaseWait := func() {
|
||||||
if accountWaitCounted {
|
if accountWaitCounted {
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -431,13 +432,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
releaseWait()
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if accountWaitCounted {
|
// Slot acquired: no longer waiting in queue.
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
releaseWait()
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -930,7 +930,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
|||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
|
||||||
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
setOpsSelectedAccount(c, account.ID)
|
setOpsSelectedAccount(c, account.ID)
|
||||||
@@ -1143,7 +1144,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
|||||||
}
|
}
|
||||||
msg := pkgerrors.Message(err)
|
msg := pkgerrors.Message(err)
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = err.Error()
|
log.Printf("[Gateway] billing error details: %v", err)
|
||||||
|
msg = "Billing error"
|
||||||
}
|
}
|
||||||
return http.StatusForbidden, "billing_error", msg
|
return http.StatusForbidden, "billing_error", msg
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -216,7 +216,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
log.Printf("[OpenAI Handler] SelectAccount failed: %v", err)
|
||||||
if len(failedAccountIDs) == 0 {
|
if len(failedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
log.Printf("[OpenAI Gateway] SelectAccount failed: %v", err)
|
||||||
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if lastFailoverErr != nil {
|
if lastFailoverErr != nil {
|
||||||
@@ -249,11 +250,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
if err == nil && canWait {
|
if err == nil && canWait {
|
||||||
accountWaitCounted = true
|
accountWaitCounted = true
|
||||||
}
|
}
|
||||||
defer func() {
|
releaseWait := func() {
|
||||||
if accountWaitCounted {
|
if accountWaitCounted {
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||||
|
accountWaitCounted = false
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||||
c,
|
c,
|
||||||
@@ -265,13 +267,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Account concurrency acquire failed: %v", err)
|
log.Printf("Account concurrency acquire failed: %v", err)
|
||||||
|
releaseWait()
|
||||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if accountWaitCounted {
|
// Slot acquired: no longer waiting in queue.
|
||||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
releaseWait()
|
||||||
accountWaitCounted = false
|
|
||||||
}
|
|
||||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs)
|
stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package antigravity
|
package antigravity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@@ -341,12 +342,16 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string {
|
|||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateRandomID 生成随机 ID
|
// generateRandomID 生成密码学安全的随机 ID
|
||||||
func generateRandomID() string {
|
func generateRandomID() string {
|
||||||
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
result := make([]byte, 12)
|
result := make([]byte, 12)
|
||||||
for i := range result {
|
randBytes := make([]byte, 12)
|
||||||
result[i] = chars[i%len(chars)]
|
if _, err := rand.Read(randBytes); err != nil {
|
||||||
|
panic("crypto/rand unavailable: " + err.Error())
|
||||||
|
}
|
||||||
|
for i, b := range randBytes {
|
||||||
|
result[i] = chars[int(b)%len(chars)]
|
||||||
}
|
}
|
||||||
return string(result)
|
return string(result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateRandomID_Uniqueness(t *testing.T) {
|
||||||
|
seen := make(map[string]struct{}, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
id := generateRandomID()
|
||||||
|
require.Len(t, id, 12, "ID 长度应为 12")
|
||||||
|
_, dup := seen[id]
|
||||||
|
require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id)
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateRandomID_Charset(t *testing.T) {
|
||||||
|
const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
validSet := make(map[byte]struct{}, len(validChars))
|
||||||
|
for i := 0; i < len(validChars); i++ {
|
||||||
|
validSet[validChars[i]] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
id := generateRandomID()
|
||||||
|
for j := 0; j < len(id); j++ {
|
||||||
|
_, ok := validSet[id[j]]
|
||||||
|
require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -54,29 +54,34 @@ func normalizeIP(ip string) string {
|
|||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPrivateIP 检查 IP 是否为私有地址。
|
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||||
func isPrivateIP(ipStr string) bool {
|
var privateNets []*net.IPNet
|
||||||
ip := net.ParseIP(ipStr)
|
|
||||||
if ip == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 私有 IP 范围
|
func init() {
|
||||||
privateBlocks := []string{
|
for _, cidr := range []string{
|
||||||
"10.0.0.0/8",
|
"10.0.0.0/8",
|
||||||
"172.16.0.0/12",
|
"172.16.0.0/12",
|
||||||
"192.168.0.0/16",
|
"192.168.0.0/16",
|
||||||
"127.0.0.0/8",
|
"127.0.0.0/8",
|
||||||
"::1/128",
|
"::1/128",
|
||||||
"fc00::/7",
|
"fc00::/7",
|
||||||
}
|
} {
|
||||||
|
_, block, err := net.ParseCIDR(cidr)
|
||||||
for _, block := range privateBlocks {
|
|
||||||
_, cidr, err := net.ParseCIDR(block)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
panic("invalid CIDR: " + cidr)
|
||||||
}
|
}
|
||||||
if cidr.Contains(ip) {
|
privateNets = append(privateNets, block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrivateIP 检查 IP 是否为私有地址。
|
||||||
|
func isPrivateIP(ipStr string) bool {
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, block := range privateNets {
|
||||||
|
if block.Contains(ip) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
51
backend/internal/pkg/ip/ip_test.go
Normal file
51
backend/internal/pkg/ip/ip_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package ip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsPrivateIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// 私有 IPv4
|
||||||
|
{"10.x 私有地址", "10.0.0.1", true},
|
||||||
|
{"10.x 私有地址段末", "10.255.255.255", true},
|
||||||
|
{"172.16.x 私有地址", "172.16.0.1", true},
|
||||||
|
{"172.31.x 私有地址", "172.31.255.255", true},
|
||||||
|
{"192.168.x 私有地址", "192.168.1.1", true},
|
||||||
|
{"127.0.0.1 本地回环", "127.0.0.1", true},
|
||||||
|
{"127.x 回环段", "127.255.255.255", true},
|
||||||
|
|
||||||
|
// 公网 IPv4
|
||||||
|
{"8.8.8.8 公网 DNS", "8.8.8.8", false},
|
||||||
|
{"1.1.1.1 公网", "1.1.1.1", false},
|
||||||
|
{"172.15.255.255 非私有", "172.15.255.255", false},
|
||||||
|
{"172.32.0.0 非私有", "172.32.0.0", false},
|
||||||
|
{"11.0.0.1 公网", "11.0.0.1", false},
|
||||||
|
|
||||||
|
// IPv6
|
||||||
|
{"::1 IPv6 回环", "::1", true},
|
||||||
|
{"fc00:: IPv6 私有", "fc00::1", true},
|
||||||
|
{"fd00:: IPv6 私有", "fd00::1", true},
|
||||||
|
{"2001:db8::1 IPv6 公网", "2001:db8::1", false},
|
||||||
|
|
||||||
|
// 无效输入
|
||||||
|
{"空字符串", "", false},
|
||||||
|
{"非法字符串", "not-an-ip", false},
|
||||||
|
{"不完整 IP", "192.168", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := isPrivateIP(tc.ip)
|
||||||
|
require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
|||||||
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
return nil, fmt.Errorf("apply TLS preset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||||
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err)
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||||
|
|||||||
@@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
|||||||
return keys, nil
|
return keys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
|
||||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||||
// Use raw SQL for atomic increment to avoid race conditions
|
updated, err := r.client.APIKey.UpdateOneID(id).
|
||||||
// First get current value
|
Where(apikey.DeletedAtIsNil()).
|
||||||
m, err := r.activeQuery().
|
AddQuotaUsed(amount).
|
||||||
Where(apikey.IDEQ(id)).
|
Save(ctx)
|
||||||
Select(apikey.FieldQuotaUsed).
|
|
||||||
Only(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if dbent.IsNotFound(err) {
|
if dbent.IsNotFound(err) {
|
||||||
return 0, service.ErrAPIKeyNotFound
|
return 0, service.ErrAPIKeyNotFound
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
return updated.QuotaUsed, nil
|
||||||
newValue := m.QuotaUsed + amount
|
|
||||||
|
|
||||||
// Update with new value
|
|
||||||
affected, err := r.client.APIKey.Update().
|
|
||||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
|
||||||
SetQuotaUsed(newValue).
|
|
||||||
Save(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if affected == 0 {
|
|
||||||
return 0, service.ErrAPIKeyNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return newValue, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||||
|
|||||||
@@ -4,11 +4,14 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
|
|||||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||||
return k
|
return k
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- IncrementQuotaUsed ---
|
||||||
|
|
||||||
|
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
|
||||||
|
user := s.mustCreateUser("incr-basic@test.com")
|
||||||
|
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
|
||||||
|
|
||||||
|
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
|
||||||
|
s.Require().NoError(err, "IncrementQuotaUsed")
|
||||||
|
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
|
||||||
|
|
||||||
|
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
|
||||||
|
s.Require().NoError(err, "IncrementQuotaUsed second")
|
||||||
|
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
|
||||||
|
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
|
||||||
|
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||||
|
user := s.mustCreateUser("incr-deleted@test.com")
|
||||||
|
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
|
||||||
|
|
||||||
|
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
|
||||||
|
|
||||||
|
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
|
||||||
|
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||||
|
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||||
|
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||||
|
client := testEntClient(t)
|
||||||
|
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 创建测试用户和 API Key
|
||||||
|
u, err := client.User.Create().
|
||||||
|
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
|
||||||
|
SetPasswordHash("hash").
|
||||||
|
SetStatus(service.StatusActive).
|
||||||
|
SetRole(service.RoleUser).
|
||||||
|
Save(ctx)
|
||||||
|
require.NoError(t, err, "create user")
|
||||||
|
|
||||||
|
k := &service.APIKey{
|
||||||
|
UserID: u.ID,
|
||||||
|
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
|
||||||
|
Name: "Concurrent",
|
||||||
|
Status: service.StatusActive,
|
||||||
|
}
|
||||||
|
require.NoError(t, repo.Create(ctx, k), "create api key")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
|
||||||
|
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 10 个 goroutine 各递增 1.0,总计应为 10.0
|
||||||
|
const goroutines = 10
|
||||||
|
const increment = 1.0
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errs := make([]error, goroutines)
|
||||||
|
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i, e := range errs {
|
||||||
|
require.NoError(t, e, "goroutine %d failed", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证最终结果
|
||||||
|
got, err := repo.GetByID(ctx, k.ID)
|
||||||
|
require.NoError(t, err, "GetByID")
|
||||||
|
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
|
||||||
|
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,8 +17,15 @@ const (
|
|||||||
billingBalanceKeyPrefix = "billing:balance:"
|
billingBalanceKeyPrefix = "billing:balance:"
|
||||||
billingSubKeyPrefix = "billing:sub:"
|
billingSubKeyPrefix = "billing:sub:"
|
||||||
billingCacheTTL = 5 * time.Minute
|
billingCacheTTL = 5 * time.Minute
|
||||||
|
billingCacheJitter = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||||
|
func jitteredTTL() time.Duration {
|
||||||
|
jitter := time.Duration(rand.Int63n(int64(2*billingCacheJitter))) - billingCacheJitter
|
||||||
|
return billingCacheTTL + jitter
|
||||||
|
}
|
||||||
|
|
||||||
// billingBalanceKey generates the Redis key for user balance cache.
|
// billingBalanceKey generates the Redis key for user balance cache.
|
||||||
func billingBalanceKey(userID int64) string {
|
func billingBalanceKey(userID int64) string {
|
||||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||||
@@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
|
|||||||
|
|
||||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||||
key := billingBalanceKey(userID)
|
key := billingBalanceKey(userID)
|
||||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||||
key := billingBalanceKey(userID)
|
key := billingBalanceKey(userID)
|
||||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
|||||||
|
|
||||||
pipe := c.rdb.Pipeline()
|
pipe := c.rdb.Pipeline()
|
||||||
pipe.HSet(ctx, key, fields)
|
pipe.HSet(ctx, key, fields)
|
||||||
pipe.Expire(ctx, key, billingCacheTTL)
|
pipe.Expire(ctx, key, jitteredTTL())
|
||||||
_, err := pipe.Exec(ctx)
|
_, err := pipe.Exec(ctx)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||||
key := billingSubKey(userID, groupID)
|
key := billingSubKey(userID, groupID)
|
||||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||||||
if err != nil && !errors.Is(err, redis.Nil) {
|
if err != nil && !errors.Is(err, redis.Nil) {
|
||||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||||
|
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||||
|
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func(ctx context.Context, cache service.BillingCache)
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "key_not_exists_returns_nil",
|
||||||
|
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||||
|
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||||
|
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||||
|
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "existing_key_deducts_successfully",
|
||||||
|
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||||
|
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||||
|
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||||
|
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||||
|
|
||||||
|
bal, err := cache.GetUserBalance(ctx, 200)
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cancelled_context_propagates_error",
|
||||||
|
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||||
|
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||||
|
|
||||||
|
cancelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel() // 立即取消
|
||||||
|
|
||||||
|
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||||
|
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := NewBillingCache(rdb)
|
||||||
|
ctx := context.Background()
|
||||||
|
tt.fn(ctx, cache)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||||
|
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||||
|
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||||
|
s.Run("key_not_exists_returns_nil", func() {
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := NewBillingCache(rdb)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||||
|
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||||
|
})
|
||||||
|
|
||||||
|
s.Run("cancelled_context_propagates_error", func() {
|
||||||
|
rdb := testRedis(s.T())
|
||||||
|
cache := NewBillingCache(rdb)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
data := &service.SubscriptionCacheData{
|
||||||
|
Status: "active",
|
||||||
|
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||||
|
Version: 1,
|
||||||
|
}
|
||||||
|
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||||
|
|
||||||
|
cancelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||||
|
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestBillingCacheSuite(t *testing.T) {
|
func TestBillingCacheSuite(t *testing.T) {
|
||||||
suite.Run(t, new(BillingCacheSuite))
|
suite.Run(t, new(BillingCacheSuite))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJitteredTTL(t *testing.T) {
|
||||||
|
const (
|
||||||
|
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||||
|
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||||
|
)
|
||||||
|
|
||||||
|
for i := 0; i < 200; i++ {
|
||||||
|
ttl := jitteredTTL()
|
||||||
|
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||||
|
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||||
|
// 多次调用应该产生不同的值(验证抖动存在)
|
||||||
|
seen := make(map[time.Duration]struct{}, 50)
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
seen[jitteredTTL()] = struct{}{}
|
||||||
|
}
|
||||||
|
// 50 次调用中应该至少有 2 个不同的值
|
||||||
|
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||||
|
}
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
|||||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||||
}
|
}
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Clone().Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
|
|||||||
q = q.Where(promocode.CodeContainsFold(search))
|
q = q.Where(promocode.CodeContainsFold(search))
|
||||||
}
|
}
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Clone().Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
|
|||||||
q := r.client.PromoCodeUsage.Query().
|
q := r.client.PromoCodeUsage.Query().
|
||||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||||
|
|
||||||
total, err := q.Count(ctx)
|
total, err := q.Clone().Count(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,22 @@ import (
|
|||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
||||||
|
|
||||||
|
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||||
|
var dateFormatWhitelist = map[string]string{
|
||||||
|
"hour": "YYYY-MM-DD HH24:00",
|
||||||
|
"day": "YYYY-MM-DD",
|
||||||
|
"week": "IYYY-IW",
|
||||||
|
"month": "YYYY-MM",
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
|
||||||
|
func safeDateFormat(granularity string) string {
|
||||||
|
if f, ok := dateFormatWhitelist[granularity]; ok {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
return "YYYY-MM-DD"
|
||||||
|
}
|
||||||
|
|
||||||
type usageLogRepository struct {
|
type usageLogRepository struct {
|
||||||
client *dbent.Client
|
client *dbent.Client
|
||||||
sql sqlExecutor
|
sql sqlExecutor
|
||||||
@@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||||
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
@@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||||
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||||||
return logs, nil, err
|
return logs, nil, err
|
||||||
}
|
}
|
||||||
@@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
|||||||
|
|
||||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := safeDateFormat(granularity)
|
||||||
if granularity == "hour" {
|
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
|
||||||
}
|
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
WITH top_keys AS (
|
WITH top_keys AS (
|
||||||
@@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
|
|||||||
|
|
||||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := safeDateFormat(granularity)
|
||||||
if granularity == "hour" {
|
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
|
||||||
}
|
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
WITH top_users AS (
|
WITH top_users AS (
|
||||||
@@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
|
|||||||
|
|
||||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := safeDateFormat(granularity)
|
||||||
if granularity == "hour" {
|
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
|
||||||
}
|
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
SELECT
|
SELECT
|
||||||
@@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats
|
|||||||
// BatchUserUsageStats represents usage stats for a single user
|
// BatchUserUsageStats represents usage stats for a single user
|
||||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||||
|
|
||||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
// If startTime is zero, defaults to 30 days ago.
|
||||||
|
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||||
result := make(map[int64]*BatchUserUsageStats)
|
result := make(map[int64]*BatchUserUsageStats)
|
||||||
if len(userIDs) == 0 {
|
if len(userIDs) == 0 {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 默认最近 30 天
|
||||||
|
if startTime.IsZero() {
|
||||||
|
startTime = time.Now().AddDate(0, 0, -30)
|
||||||
|
}
|
||||||
|
if endTime.IsZero() {
|
||||||
|
endTime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
for _, id := range userIDs {
|
for _, id := range userIDs {
|
||||||
result[id] = &BatchUserUsageStats{UserID: id}
|
result[id] = &BatchUserUsageStats{UserID: id}
|
||||||
}
|
}
|
||||||
@@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
|||||||
query := `
|
query := `
|
||||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE user_id = ANY($1)
|
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||||
GROUP BY user_id
|
GROUP BY user_id
|
||||||
`
|
`
|
||||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
|||||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||||
|
|
||||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
|
||||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
// If startTime is zero, defaults to 30 days ago.
|
||||||
|
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||||
if len(apiKeyIDs) == 0 {
|
if len(apiKeyIDs) == 0 {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 默认最近 30 天
|
||||||
|
if startTime.IsZero() {
|
||||||
|
startTime = time.Now().AddDate(0, 0, -30)
|
||||||
|
}
|
||||||
|
if endTime.IsZero() {
|
||||||
|
endTime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
for _, id := range apiKeyIDs {
|
for _, id := range apiKeyIDs {
|
||||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||||
}
|
}
|
||||||
@@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
|||||||
query := `
|
query := `
|
||||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE api_key_id = ANY($1)
|
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||||
GROUP BY api_key_id
|
GROUP BY api_key_id
|
||||||
`
|
`
|
||||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
|
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
|||||||
|
|
||||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||||
dateFormat := "YYYY-MM-DD"
|
dateFormat := safeDateFormat(granularity)
|
||||||
if granularity == "hour" {
|
|
||||||
dateFormat = "YYYY-MM-DD HH24:00"
|
|
||||||
}
|
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
SELECT
|
SELECT
|
||||||
|
|||||||
@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
|||||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||||
|
|
||||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
|
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
|
||||||
s.Require().NoError(err, "GetBatchUserUsageStats")
|
s.Require().NoError(err, "GetBatchUserUsageStats")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
s.Require().NotNil(stats[user1.ID])
|
s.Require().NotNil(stats[user1.ID])
|
||||||
@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
|
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Empty(stats)
|
s.Require().Empty(stats)
|
||||||
}
|
}
|
||||||
@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
|||||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||||
|
|
||||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
|
||||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||||
s.Require().Len(stats, 2)
|
s.Require().Len(stats, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||||
s.Require().NoError(err)
|
s.Require().NoError(err)
|
||||||
s.Require().Empty(stats)
|
s.Require().Empty(stats)
|
||||||
}
|
}
|
||||||
|
|||||||
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSafeDateFormat(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
granularity string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
// 合法值
|
||||||
|
{"hour", "hour", "YYYY-MM-DD HH24:00"},
|
||||||
|
{"day", "day", "YYYY-MM-DD"},
|
||||||
|
{"week", "week", "IYYY-IW"},
|
||||||
|
{"month", "month", "YYYY-MM"},
|
||||||
|
|
||||||
|
// 非法值回退到默认
|
||||||
|
{"空字符串", "", "YYYY-MM-DD"},
|
||||||
|
{"未知粒度 year", "year", "YYYY-MM-DD"},
|
||||||
|
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
|
||||||
|
|
||||||
|
// 恶意字符串
|
||||||
|
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
|
||||||
|
{"带引号", "day'", "YYYY-MM-DD"},
|
||||||
|
{"带括号", "day)", "YYYY-MM-DD"},
|
||||||
|
{"Unicode", "日", "YYYY-MM-DD"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got := safeDateFormat(tc.granularity)
|
||||||
|
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -592,7 +592,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
RunMode: config.RunModeStandard,
|
RunMode: config.RunModeStandard,
|
||||||
}
|
}
|
||||||
|
|
||||||
userService := service.NewUserService(userRepo, nil)
|
userService := service.NewUserService(userRepo, nil, nil)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
||||||
|
|
||||||
usageRepo := newStubUsageLogRepo()
|
usageRepo := newStubUsageLogRepo()
|
||||||
@@ -1598,11 +1598,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
userService := service.NewUserService(userRepo, nil)
|
userService := service.NewUserService(userRepo, nil, nil)
|
||||||
|
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
|||||||
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||||
|
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||||
|
|
||||||
// 处理预检请求
|
// 处理预检请求
|
||||||
if c.Request.Method == http.MethodOptions {
|
if c.Request.Method == http.MethodOptions {
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ type UsageLogRepository interface {
|
|||||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||||
|
|
||||||
// User dashboard stats
|
// User dashboard stats
|
||||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||||
|
|||||||
@@ -1687,7 +1687,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
|||||||
Usage: ClaudeUsage{},
|
Usage: ClaudeUsage{},
|
||||||
Model: originalModel,
|
Model: originalModel,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
Duration: time.Since(time.Now()),
|
Duration: time.Since(startTime),
|
||||||
FirstTokenMs: nil,
|
FirstTokenMs: nil,
|
||||||
}, nil
|
}, nil
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end
|
|||||||
return trend, nil
|
return trend, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
|
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
||||||
}
|
}
|
||||||
return stats, nil
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -316,8 +316,8 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
@@ -62,13 +64,15 @@ type ChangePasswordRequest struct {
|
|||||||
type UserService struct {
|
type UserService struct {
|
||||||
userRepo UserRepository
|
userRepo UserRepository
|
||||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||||
|
billingCache BillingCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserService 创建用户服务实例
|
// NewUserService 创建用户服务实例
|
||||||
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
|
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator, billingCache BillingCache) *UserService {
|
||||||
return &UserService{
|
return &UserService{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
authCacheInvalidator: authCacheInvalidator,
|
authCacheInvalidator: authCacheInvalidator,
|
||||||
|
billingCache: billingCache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,6 +187,15 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
|||||||
if s.authCacheInvalidator != nil {
|
if s.authCacheInvalidator != nil {
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
|
||||||
}
|
}
|
||||||
|
if s.billingCache != nil {
|
||||||
|
go func() {
|
||||||
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||||
|
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
186
backend/internal/service/user_service_test.go
Normal file
186
backend/internal/service/user_service_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- mock: UserRepository ---
|
||||||
|
|
||||||
|
type mockUserRepo struct {
|
||||||
|
updateBalanceErr error
|
||||||
|
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
|
||||||
|
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
|
||||||
|
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
|
||||||
|
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
|
||||||
|
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
|
||||||
|
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
|
||||||
|
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockUserRepo) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||||
|
if m.updateBalanceFn != nil {
|
||||||
|
return m.updateBalanceFn(ctx, id, amount)
|
||||||
|
}
|
||||||
|
return m.updateBalanceErr
|
||||||
|
}
|
||||||
|
func (m *mockUserRepo) DeductBalance(context.Context, int64, float64) error { return nil }
|
||||||
|
func (m *mockUserRepo) UpdateConcurrency(context.Context, int64, int) error { return nil }
|
||||||
|
func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { return false, nil }
|
||||||
|
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
|
||||||
|
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
|
||||||
|
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
|
||||||
|
|
||||||
|
// --- mock: APIKeyAuthCacheInvalidator ---
|
||||||
|
|
||||||
|
type mockAuthCacheInvalidator struct {
|
||||||
|
invalidatedUserIDs []int64
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {}
|
||||||
|
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {}
|
||||||
|
func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- mock: BillingCache ---
|
||||||
|
|
||||||
|
type mockBillingCache struct {
|
||||||
|
invalidateErr error
|
||||||
|
invalidateCallCount atomic.Int64
|
||||||
|
invalidatedUserIDs []int64
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil }
|
||||||
|
func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil }
|
||||||
|
func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil }
|
||||||
|
func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error {
|
||||||
|
m.invalidateCallCount.Add(1)
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID)
|
||||||
|
return m.invalidateErr
|
||||||
|
}
|
||||||
|
func (m *mockBillingCache) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockBillingCache) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockBillingCache) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockBillingCache) InvalidateSubscriptionCache(context.Context, int64, int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- 测试 ---
|
||||||
|
|
||||||
|
func TestUpdateBalance_Success(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{}
|
||||||
|
cache := &mockBillingCache{}
|
||||||
|
svc := NewUserService(repo, nil, cache)
|
||||||
|
|
||||||
|
err := svc.UpdateBalance(context.Background(), 42, 100.0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 等待异步 goroutine 完成
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cache.invalidateCallCount.Load() == 1
|
||||||
|
}, 2*time.Second, 10*time.Millisecond, "应异步调用 InvalidateUserBalance")
|
||||||
|
|
||||||
|
cache.mu.Lock()
|
||||||
|
defer cache.mu.Unlock()
|
||||||
|
require.Equal(t, []int64{42}, cache.invalidatedUserIDs, "应对 userID=42 失效缓存")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateBalance_NilBillingCache_NoPanic(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{}
|
||||||
|
svc := NewUserService(repo, nil, nil) // billingCache = nil
|
||||||
|
|
||||||
|
err := svc.UpdateBalance(context.Background(), 1, 50.0)
|
||||||
|
require.NoError(t, err, "billingCache 为 nil 时不应 panic")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateBalance_CacheFailure_DoesNotAffectReturn(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{}
|
||||||
|
cache := &mockBillingCache{invalidateErr: errors.New("redis connection refused")}
|
||||||
|
svc := NewUserService(repo, nil, cache)
|
||||||
|
|
||||||
|
err := svc.UpdateBalance(context.Background(), 99, 200.0)
|
||||||
|
require.NoError(t, err, "缓存失效失败不应影响主流程返回值")
|
||||||
|
|
||||||
|
// 等待异步 goroutine 完成(即使失败也应调用)
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cache.invalidateCallCount.Load() == 1
|
||||||
|
}, 2*time.Second, 10*time.Millisecond, "即使失败也应调用 InvalidateUserBalance")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateBalance_RepoError_ReturnsError(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{updateBalanceErr: errors.New("database error")}
|
||||||
|
cache := &mockBillingCache{}
|
||||||
|
svc := NewUserService(repo, nil, cache)
|
||||||
|
|
||||||
|
err := svc.UpdateBalance(context.Background(), 1, 100.0)
|
||||||
|
require.Error(t, err, "repo 失败时应返回错误")
|
||||||
|
require.Contains(t, err.Error(), "update balance")
|
||||||
|
|
||||||
|
// repo 失败时不应触发缓存失效
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
require.Equal(t, int64(0), cache.invalidateCallCount.Load(),
|
||||||
|
"repo 失败时不应调用 InvalidateUserBalance")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateBalance_WithAuthCacheInvalidator(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{}
|
||||||
|
auth := &mockAuthCacheInvalidator{}
|
||||||
|
cache := &mockBillingCache{}
|
||||||
|
svc := NewUserService(repo, auth, cache)
|
||||||
|
|
||||||
|
err := svc.UpdateBalance(context.Background(), 77, 300.0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 验证 auth cache 同步失效
|
||||||
|
auth.mu.Lock()
|
||||||
|
require.Equal(t, []int64{77}, auth.invalidatedUserIDs)
|
||||||
|
auth.mu.Unlock()
|
||||||
|
|
||||||
|
// 验证 billing cache 异步失效
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return cache.invalidateCallCount.Load() == 1
|
||||||
|
}, 2*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewUserService_FieldsAssignment(t *testing.T) {
|
||||||
|
repo := &mockUserRepo{}
|
||||||
|
auth := &mockAuthCacheInvalidator{}
|
||||||
|
cache := &mockBillingCache{}
|
||||||
|
|
||||||
|
svc := NewUserService(repo, auth, cache)
|
||||||
|
require.NotNil(t, svc)
|
||||||
|
require.Equal(t, repo, svc.userRepo)
|
||||||
|
require.Equal(t, auth, svc.authCacheInvalidator)
|
||||||
|
require.Equal(t, cache, svc.billingCache)
|
||||||
|
}
|
||||||
@@ -112,9 +112,9 @@ security:
|
|||||||
# 白名单禁用时是否允许 http:// URL(默认: false,要求 https)
|
# 白名单禁用时是否允许 http:// URL(默认: false,要求 https)
|
||||||
allow_insecure_http: true
|
allow_insecure_http: true
|
||||||
response_headers:
|
response_headers:
|
||||||
# Enable configurable response header filtering (disable to use default allowlist)
|
# Enable configurable response header filtering (default: true)
|
||||||
# 启用可配置的响应头过滤(禁用则使用默认白名单)
|
# 启用可配置的响应头过滤(默认启用,过滤上游敏感响应头)
|
||||||
enabled: false
|
enabled: true
|
||||||
# Extra allowed response headers from upstream
|
# Extra allowed response headers from upstream
|
||||||
# 额外允许的上游响应头
|
# 额外允许的上游响应头
|
||||||
additional_allowed: []
|
additional_allowed: []
|
||||||
@@ -390,15 +390,16 @@ database:
|
|||||||
# Database name
|
# Database name
|
||||||
# 数据库名称
|
# 数据库名称
|
||||||
dbname: "sub2api"
|
dbname: "sub2api"
|
||||||
# SSL mode: disable, require, verify-ca, verify-full
|
# SSL mode: disable, prefer, require, verify-ca, verify-full
|
||||||
# SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证)
|
# SSL 模式:disable(禁用), prefer(优先加密,默认), require(要求), verify-ca(验证CA), verify-full(完全验证)
|
||||||
sslmode: "disable"
|
# 默认值为 "prefer",数据库支持 SSL 时自动使用加密连接,不支持时回退明文
|
||||||
# Max open connections
|
sslmode: "prefer"
|
||||||
|
# Max open connections (高并发场景建议 256+,需配合 PostgreSQL max_connections 调整)
|
||||||
# 最大打开连接数
|
# 最大打开连接数
|
||||||
max_open_conns: 50
|
max_open_conns: 256
|
||||||
# Max idle connections
|
# Max idle connections (建议为 max_open_conns 的 50%,减少频繁建连开销)
|
||||||
# 最大空闲连接数
|
# 最大空闲连接数
|
||||||
max_idle_conns: 10
|
max_idle_conns: 128
|
||||||
# Connection max lifetime (minutes)
|
# Connection max lifetime (minutes)
|
||||||
# 连接最大存活时间(分钟)
|
# 连接最大存活时间(分钟)
|
||||||
conn_max_lifetime_minutes: 30
|
conn_max_lifetime_minutes: 30
|
||||||
@@ -426,9 +427,9 @@ redis:
|
|||||||
# Connection pool size (max concurrent connections)
|
# Connection pool size (max concurrent connections)
|
||||||
# 连接池大小(最大并发连接数)
|
# 连接池大小(最大并发连接数)
|
||||||
pool_size: 1024
|
pool_size: 1024
|
||||||
# Minimum number of idle connections
|
# Minimum number of idle connections (高并发场景建议 128+,保持足够热连接)
|
||||||
# 最小空闲连接数
|
# 最小空闲连接数
|
||||||
min_idle_conns: 10
|
min_idle_conns: 128
|
||||||
# Enable TLS/SSL connection
|
# Enable TLS/SSL connection
|
||||||
# 是否启用 TLS/SSL 连接
|
# 是否启用 TLS/SSL 连接
|
||||||
enable_tls: false
|
enable_tls: false
|
||||||
|
|||||||
Reference in New Issue
Block a user