Merge upstream/main: v0.1.85-v0.1.86 updates
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
huangzhenpc
2026-02-25 18:39:59 +08:00
478 changed files with 70574 additions and 35853 deletions

View File

@@ -83,6 +83,7 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
@@ -122,6 +123,7 @@ func TestAPIContracts(t *testing.T) {
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"last_used_at": null,
"quota": 0,
"quota_used": 0,
"expires_at": null,
@@ -184,6 +186,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null,
"image_price_2k": null,
"image_price_4k": null,
"sora_image_price_360": null,
"sora_image_price_540": null,
"sora_video_price_per_request": null,
"sora_video_price_per_request_hd": null,
"claude_code_only": false,
"fallback_group_id": null,
"fallback_group_id_on_invalid_request": null,
@@ -401,6 +407,7 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
"media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
"user_agent": null
@@ -593,13 +600,13 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo, nil)
userService := service.NewUserService(userRepo, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
@@ -608,7 +615,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -925,6 +932,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
return errors.New("not implemented")
}
@@ -1462,6 +1473,20 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
key, ok := r.byID[id]
if !ok {
return service.ErrAPIKeyNotFound
}
ts := usedAt
key.LastUsedAt = &ts
key.UpdatedAt = usedAt
clone := *key
r.byID[id] = &clone
r.byKey[clone.Key] = &clone
return nil
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
@@ -1607,11 +1632,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
return nil, errors.New("not implemented")
}

View File

@@ -51,6 +51,9 @@ func ProvideRouter(
if err := r.SetTrustedProxies(nil); err != nil {
log.Printf("Failed to disable trusted proxies: %v", err)
}
if cfg.Server.Mode == "release" {
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)

View File

@@ -58,8 +58,13 @@ func adminAuth(
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userService) {
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
token := strings.TrimSpace(parts[1])
if token == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
return
}
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
@@ -176,6 +181,12 @@ func validateJWTForAdmin(
return false
}
// 校验 TokenVersion确保管理员改密后旧 token 失效
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return false
}
// 检查管理员权限
if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")

View File

@@ -0,0 +1,194 @@
//go:build unit
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
Email: "admin@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
TokenVersion: 2,
Concurrency: 1,
}
userRepo := &stubUserRepo{
getByID: func(ctx context.Context, id int64) (*service.User, error) {
if id != admin.ID {
return nil, service.ErrUserNotFound
}
clone := *admin
return &clone, nil
},
}
userService := service.NewUserService(userRepo, nil, nil)
router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion - 1,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
})
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
token, err := authService.GenerateToken(&service.User{
ID: admin.ID,
Email: admin.Email,
Role: admin.Role,
TokenVersion: admin.TokenVersion,
})
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
}
type stubUserRepo struct {
getByID func(ctx context.Context, id int64) (*service.User, error)
}
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
panic("unexpected Create call")
}
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
if s.getByID == nil {
panic("GetByID not stubbed")
}
return s.getByID(ctx, id)
}
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
panic("unexpected GetByEmail call")
}
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
panic("unexpected Update call")
}
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
panic("unexpected ExistsByEmail call")
}
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}

View File

@@ -3,7 +3,6 @@ package middleware
import (
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -36,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
apiKeyString = strings.TrimSpace(parts[1])
}
}
@@ -97,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
@@ -126,6 +125,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
return
}
@@ -134,7 +134,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
// 订阅模式:获取订阅L1 缓存 + singleflight
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
@@ -145,30 +145,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 验证订阅状态(是否过期、暂停等
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
// 合并验证 + 限额检查(纯内存操作
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
if err != nil {
code := "SUBSCRIPTION_INVALID"
status := 403
if errors.Is(err, service.ErrDailyLimitExceeded) ||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
code = "USAGE_LIMIT_EXCEEDED"
status = 429
}
AbortWithError(c, status, code, err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
// 窗口维护异步化(不阻塞请求)
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
if needsMaintenance {
maintenanceCopy := *subscription
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
}
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
@@ -185,6 +185,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
}

View File

@@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
return
}
@@ -104,6 +105,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
c.Next()
}
}

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -18,7 +19,8 @@ import (
)
type fakeAPIKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
}
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
@@ -78,6 +80,12 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([
func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
if f.updateLastUsed != nil {
return f.updateLastUsed(ctx, id, usedAt)
}
return nil
}
type googleErrorResponse struct {
Error struct {
@@ -356,3 +364,144 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
require.Equal(t, "Insufficient account balance", resp.Error.Message)
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 11,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 201,
UserID: user.ID,
Key: "google-touch-ok",
Status: service.StatusActive,
User: user,
}
var touchedID int64
var touchedAt time.Time
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
touchedID = id
touchedAt = usedAt
return nil
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-goog-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, apiKey.ID, touchedID)
require.False(t, touchedAt.IsZero())
}
func TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 12,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 202,
UserID: user.ID,
Key: "google-touch-fail",
Status: service.StatusActive,
User: user,
}
touchCalls := 0
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
touchCalls++
return errors.New("write failed")
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-goog-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, touchCalls)
}
func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 13,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 203,
UserID: user.ID,
Key: "google-touch-standard",
Status: service.StatusActive,
User: user,
}
touchCalls := 0
r := gin.New()
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
touchCalls++
return nil
},
})
cfg := &config.Config{RunMode: config.RunModeStandard}
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer "+apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, 1, touchCalls)
}

View File

@@ -57,10 +57,41 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
},
}
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
cfg.SubscriptionMaintenance.WorkerCount = 1
cfg.SubscriptionMaintenance.QueueSize = 1
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
past := time.Now().Add(-48 * time.Hour)
sub := &service.UserSubscription{
ID: 55,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyWindowStart: &past,
DailyUsageUSD: 0,
}
maintenanceCalled := make(chan struct{}, 1)
subscriptionRepo := &stubUserSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error {
maintenanceCalled <- struct{}{}
return nil
},
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
t.Cleanup(subscriptionService.Stop)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
@@ -68,6 +99,40 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
select {
case <-maintenanceCalled:
// ok
case <-time.After(time.Second):
t.Fatalf("expected maintenance to be scheduled")
}
})
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("Authorization", "bearer "+apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
@@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
@@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
IPWhitelist: []string{"1.2.3.4"},
}
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
require.NoError(t, router.SetTrustedProxies(nil))
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("x-api-key", apiKey.Key)
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusForbidden, w.Code)
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
}
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "touch-ok",
Status: service.StatusActive,
User: user,
}
var touchedID int64
var touchedAt time.Time
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
touchedID = id
touchedAt = usedAt
return nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, nil, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, apiKey.ID, touchedID)
require.False(t, touchedAt.IsZero(), "expected touch timestamp")
}
func TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 8,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 101,
UserID: user.ID,
Key: "touch-fail",
Status: service.StatusActive,
User: user,
}
touchCalls := 0
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
touchCalls++
return errors.New("db unavailable")
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, nil, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "touch failure should not block request")
require.Equal(t, 1, touchCalls)
}
func TestAPIKeyAuthTouchesLastUsedInStandardMode(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 9,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 102,
UserID: user.ID,
Key: "touch-standard",
Status: service.StatusActive,
User: user,
}
touchCalls := 0
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
touchCalls++
return nil
},
}
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := newAuthTestRouter(apiKeyService, nil, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
require.Equal(t, 1, touchCalls)
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
@@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
@@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
if r.updateLastUsed != nil {
return r.updateLastUsed(ctx, id, usedAt)
}
return nil
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error

View File

@@ -2,10 +2,13 @@ package middleware
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
}
id := uuid.New().String()
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id))
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
ctx = logger.IntoContext(ctx, requestLogger)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}

View File

@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
}
allowedSet[origin] = struct{}{}
}
allowHeaders := []string{
"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
"accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
}
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
openAIProperties := []string{
"lang", "package-version", "os", "arch", "retry-count", "runtime",
"runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
}
for _, prop := range openAIProperties {
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
}
allowHeadersValue := strings.Join(allowHeaders, ", ")
return func(c *gin.Context) {
origin := strings.TrimSpace(c.GetHeader("Origin"))
@@ -68,19 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
}
allowHeaders := []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key"}
// openai node sdk
openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"}
for _, prop := range openAIProperties {
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
}
c.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ", "))
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == http.MethodOptions {
if originAllowed {

View File

@@ -0,0 +1,308 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func init() {
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
gin.SetMode(gin.TestMode)
}
// --- Task 8.2: 验证 CORS 条件化头部 ---
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
origin string
}{
{
name: "preflight_disallowed_origin",
method: http.MethodOptions,
origin: "https://evil.example.com",
},
{
name: "get_disallowed_origin",
method: http.MethodGet,
origin: "https://evil.example.com",
},
{
name: "post_disallowed_origin",
method: http.MethodPost,
origin: "https://attacker.example.com",
},
{
name: "preflight_no_origin",
method: http.MethodOptions,
origin: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
if tt.origin != "" {
c.Request.Header.Set("Origin", tt.origin)
}
middleware(c)
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
"不允许的 origin 不应收到 Allow-Headers")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
"不允许的 origin 不应收到 Allow-Methods")
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
"不允许的 origin 不应收到 Max-Age")
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
"不允许的 origin 不应收到 Allow-Origin")
})
}
}
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
tests := []struct {
name string
method string
}{
{name: "preflight_OPTIONS", method: http.MethodOptions},
{name: "normal_GET", method: http.MethodGet},
{name: "normal_POST", method: http.MethodPost},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(tt.method, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"允许的 origin 应收到 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"允许的 origin 应收到 Allow-Methods")
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
"允许的 origin 应收到 Max-Age=86400")
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
"允许的 origin 应收到 Allow-Origin")
})
}
}
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Equal(t, http.StatusForbidden, w.Code,
"不允许的 origin 的 preflight 请求应返回 403")
}
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, http.StatusNoContent, w.Code,
"允许的 origin 的 preflight 请求应返回 204")
}
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any-origin.example.com")
middleware(c)
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
"通配符配置应返回 *")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
"通配符 origin 应设置 Allow-Headers")
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
"通配符 origin 应设置 Allow-Methods")
}
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: true,
}
middleware := CORS(cfg)
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
})
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"不允许的 origin 不应收到 Allow-Credentials")
})
}
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://any.example.com")
middleware(c)
// 通配符 + credentials 不兼容credentials 应被禁用
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
"通配符 origin 应禁用 Allow-Credentials")
}
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{
"https://app1.example.com",
"https://app2.example.com",
},
AllowCredentials: false,
}
middleware := CORS(cfg)
t.Run("first_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app1.example.com")
middleware(c)
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("second_origin_allowed", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app2.example.com")
middleware(c)
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
t.Run("unlisted_origin_rejected", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://app3.example.com")
middleware(c)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
})
}
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
cfg := config.CORSConfig{
AllowedOrigins: []string{"https://allowed.example.com"},
AllowCredentials: false,
}
middleware := CORS(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
c.Request.Header.Set("Origin", "https://allowed.example.com")
middleware(c)
assert.Contains(t, w.Header().Values("Vary"), "Origin",
"非通配符允许的 origin 应设置 Vary: Origin")
}
func TestNormalizeOrigins(t *testing.T) {
tests := []struct {
name string
input []string
expect []string
}{
{name: "nil_input", input: nil, expect: nil},
{name: "empty_input", input: []string{}, expect: nil},
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeOrigins(tt.input)
assert.Equal(t, tt.expect, result)
})
}
}

View File

@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
tokenString := strings.TrimSpace(parts[1])
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return

View File

@@ -0,0 +1,256 @@
//go:build unit
package middleware
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。
type stubJWTUserRepo struct {
service.UserRepository
users map[int64]*service.User
}
func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) {
u, ok := r.users[id]
if !ok {
return nil, errors.New("user not found")
}
return u, nil
}
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
// 返回 gin.Engine已注册 JWT 中间件)和 AuthService用于生成 Token
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{}
cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
r := gin.New()
r.Use(gin.HandlerFunc(mw))
r.GET("/protected", func(c *gin.Context) {
subject, _ := GetAuthSubjectFromContext(c)
role, _ := GetUserRoleFromContext(c)
c.JSON(http.StatusOK, gin.H{
"user_id": subject.UserID,
"role": role,
})
})
return r, authSvc
}
func TestJWTAuth_ValidToken(t *testing.T) {
user := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
Concurrency: 5,
TokenVersion: 1,
}
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
token, err := authSvc.GenerateToken(user)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
var body map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, float64(1), body["user_id"])
require.Equal(t, "user", body["role"])
}
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
user := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
Concurrency: 5,
TokenVersion: 1,
}
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
token, err := authSvc.GenerateToken(user)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
router, _ := newJWTTestEnv(nil)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "UNAUTHORIZED", body.Code)
}
func TestJWTAuth_InvalidHeaderFormat(t *testing.T) {
tests := []struct {
name string
header string
}{
{"无Bearer前缀", "Token abc123"},
{"缺少空格分隔", "Bearerabc123"},
{"仅有单词", "abc123"},
}
router, _ := newJWTTestEnv(nil)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", tt.header)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "INVALID_AUTH_HEADER", body.Code)
})
}
}
func TestJWTAuth_EmptyToken(t *testing.T) {
router, _ := newJWTTestEnv(nil)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer ")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "EMPTY_TOKEN", body.Code)
}
func TestJWTAuth_TamperedToken(t *testing.T) {
router, _ := newJWTTestEnv(nil)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "INVALID_TOKEN", body.Code)
}
func TestJWTAuth_UserNotFound(t *testing.T) {
// 使用 user ID=1 的 token但 repo 中没有该用户
fakeUser := &service.User{
ID: 999,
Email: "ghost@example.com",
Role: "user",
Status: service.StatusActive,
TokenVersion: 1,
}
// 创建环境时不注入此用户,这样 GetByID 会失败
router, authSvc := newJWTTestEnv(map[int64]*service.User{})
token, err := authSvc.GenerateToken(fakeUser)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "USER_NOT_FOUND", body.Code)
}
func TestJWTAuth_UserInactive(t *testing.T) {
user := &service.User{
ID: 1,
Email: "disabled@example.com",
Role: "user",
Status: service.StatusDisabled,
TokenVersion: 1,
}
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
token, err := authSvc.GenerateToken(user)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "USER_INACTIVE", body.Code)
}
func TestJWTAuth_TokenVersionMismatch(t *testing.T) {
// Token 生成时 TokenVersion=1但数据库中用户已更新为 TokenVersion=2密码修改
userForToken := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
TokenVersion: 1,
}
userInDB := &service.User{
ID: 1,
Email: "test@example.com",
Role: "user",
Status: service.StatusActive,
TokenVersion: 2, // 密码修改后版本递增
}
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB})
token, err := authSvc.GenerateToken(userForToken)
require.NoError(t, err)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusUnauthorized, w.Code)
var body ErrorResponse
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.Equal(t, "TOKEN_REVOKED", body.Code)
}

View File

@@ -1,10 +1,12 @@
package middleware
import (
"log"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// Logger 请求日志中间件
@@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc {
// 开始时间
startTime := time.Now()
// 处理请求
c.Next()
// 结束时间
endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
// 状态码
// 处理请求
c.Next()
// 跳过健康检查等高频探针路径的日志
if path == "/health" || path == "/setup/status" {
return
}
endTime := time.Now()
latency := endTime.Sub(startTime)
method := c.Request.Method
statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP()
// 协议版本
protocol := c.Request.Proto
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)
model, _ := c.Request.Context().Value(ctxkey.Model).(string)
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
protocol,
method,
path,
)
fields := []zap.Field{
zap.String("component", "http.access"),
zap.Int("status_code", statusCode),
zap.Int64("latency_ms", latency.Milliseconds()),
zap.String("client_ip", clientIP),
zap.String("protocol", protocol),
zap.String("method", method),
zap.String("path", path),
}
if hasAccountID && accountID > 0 {
fields = append(fields, zap.Int64("account_id", accountID))
}
if platform != "" {
fields = append(fields, zap.String("platform", platform))
}
if model != "" {
fields = append(fields, zap.String("model", model))
}
l := logger.FromContext(c.Request.Context()).With(fields...)
l.Info("http request completed", zap.Time("completed_at", endTime))
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String())
l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String()))
}
}
}

View File

@@ -0,0 +1,126 @@
//go:build unit
package middleware
import (
"bytes"
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(ClientRequestID())
r.GET("/t", func(c *gin.Context) {
v := c.Request.Context().Value(ctxkey.ClientRequestID)
require.NotNil(t, v)
id, ok := v.(string)
require.True(t, ok)
require.NotEmpty(t, id)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestClientRequestID_PreservesExisting(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(ClientRequestID())
r.GET("/t", func(c *gin.Context) {
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
require.True(t, ok)
require.Equal(t, "keep", id)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(RequestBodyLimit(4))
r.POST("/t", func(c *gin.Context) {
_, err := io.ReadAll(c.Request.Body)
require.Error(t, err)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(ForcePlatform("anthropic"))
r.GET("/t", func(c *gin.Context) {
require.True(t, HasForcePlatform(c))
v, ok := GetForcePlatformFromContext(c)
require.True(t, ok)
require.Equal(t, "anthropic", v)
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
require.Equal(t, "anthropic", ctxV)
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
c := &gin.Context{}
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
c.Set(string(ContextKeyUserRole), "admin")
sub, ok := GetAuthSubjectFromContext(c)
require.True(t, ok)
require.Equal(t, int64(1), sub.UserID)
require.Equal(t, 2, sub.Concurrency)
role, ok := GetUserRoleFromContext(c)
require.True(t, ok)
require.Equal(t, "admin", role)
}
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
c := &gin.Context{}
key := &service.APIKey{ID: 1}
c.Set(string(ContextKeyAPIKey), key)
gotKey, ok := GetAPIKeyFromContext(c)
require.True(t, ok)
require.Equal(t, int64(1), gotKey.ID)
sub := &service.UserSubscription{ID: 2}
c.Set(string(ContextKeySubscription), sub)
gotSub, ok := GetSubscriptionFromContext(c)
require.True(t, ok)
require.Equal(t, int64(2), gotSub.ID)
}

View File

@@ -3,6 +3,7 @@
package middleware
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -14,6 +15,34 @@ import (
"github.com/stretchr/testify/require"
)
func TestRecovery_PanicLogContainsInfo(t *testing.T) {
gin.SetMode(gin.TestMode)
// 临时替换 DefaultErrorWriter 以捕获日志输出
var buf bytes.Buffer
originalWriter := gin.DefaultErrorWriter
gin.DefaultErrorWriter = &buf
t.Cleanup(func() {
gin.DefaultErrorWriter = originalWriter
})
r := gin.New()
r.Use(Recovery())
r.GET("/panic", func(c *gin.Context) {
panic("custom panic message for test")
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
r.ServeHTTP(w, req)
require.Equal(t, http.StatusInternalServerError, w.Code)
logOutput := buf.String()
require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息")
require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名")
}
func TestRecovery(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -0,0 +1,228 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
)
type testLogSink struct {
mu sync.Mutex
events []*logger.LogEvent
}
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
s.mu.Lock()
defer s.mu.Unlock()
s.events = append(s.events, event)
}
func (s *testLogSink) list() []*logger.LogEvent {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*logger.LogEvent, len(s.events))
copy(out, s.events)
return out
}
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
return initMiddlewareTestLoggerWithLevel(t, "debug")
}
func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink {
t.Helper()
level = strings.TrimSpace(level)
if level == "" {
level = "debug"
}
if err := logger.Init(logger.InitOptions{
Level: level,
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
sink := &testLogSink{}
logger.SetSink(sink)
t.Cleanup(func() {
logger.SetSink(nil)
})
return sink
}
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(RequestLogger())
r.GET("/t", func(c *gin.Context) {
reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
if !ok || reqID == "" {
t.Fatalf("request_id missing in context")
}
if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
}
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if w.Header().Get(requestIDHeader) == "" {
t.Fatalf("X-Request-ID should be set")
}
}
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
r.Use(RequestLogger())
r.GET("/t", func(c *gin.Context) {
reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
if reqID != "rid-fixed" {
t.Fatalf("request_id=%q, want rid-fixed", reqID)
}
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set(requestIDHeader, "rid-fixed")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
t.Fatalf("header=%q, want rid-fixed", got)
}
}
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLogger(t)
r := gin.New()
r.Use(Logger())
r.Use(func(c *gin.Context) {
ctx := c.Request.Context()
ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
c.Request = c.Request.WithContext(ctx)
c.Next()
})
r.GET("/api/test", func(c *gin.Context) {
c.Status(http.StatusCreated)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status=%d", w.Code)
}
events := sink.list()
if len(events) == 0 {
t.Fatalf("expected at least one log event")
}
found := false
for _, event := range events {
if event == nil || event.Message != "http request completed" {
continue
}
found = true
switch v := event.Fields["status_code"].(type) {
case int:
if v != http.StatusCreated {
t.Fatalf("status_code field mismatch: %v", v)
}
case int64:
if v != int64(http.StatusCreated) {
t.Fatalf("status_code field mismatch: %v", v)
}
default:
t.Fatalf("status_code type mismatch: %T", v)
}
switch v := event.Fields["account_id"].(type) {
case int64:
if v != 101 {
t.Fatalf("account_id field mismatch: %v", v)
}
case int:
if v != 101 {
t.Fatalf("account_id field mismatch: %v", v)
}
default:
t.Fatalf("account_id type mismatch: %T", v)
}
if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
t.Fatalf("platform/model mismatch: %+v", event.Fields)
}
}
if !found {
t.Fatalf("access log event not found")
}
}
func TestLogger_HealthPathSkipped(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLogger(t)
r := gin.New()
r.Use(Logger())
r.GET("/health", func(c *gin.Context) {
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
if len(sink.list()) != 0 {
t.Fatalf("health endpoint should not write access log")
}
}
func TestLogger_AccessLogDroppedWhenLevelWarn(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLoggerWithLevel(t, "warn")
r := gin.New()
r.Use(RequestLogger())
r.Use(Logger())
r.GET("/api/test", func(c *gin.Context) {
c.Status(http.StatusCreated)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status=%d", w.Code)
}
events := sink.list()
for _, event := range events {
if event != nil && event.Message == "http request completed" {
t.Fatalf("access log should not be indexed when level=warn: %+v", event)
}
}
}

View File

@@ -0,0 +1,45 @@
package middleware
import (
"context"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const requestIDHeader = "X-Request-ID"
// RequestLogger 在请求入口注入 request-scoped logger。
func RequestLogger() gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request == nil {
c.Next()
return
}
requestID := strings.TrimSpace(c.GetHeader(requestIDHeader))
if requestID == "" {
requestID = uuid.NewString()
}
c.Header(requestIDHeader, requestID)
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID)
clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string)
requestLogger := logger.With(
zap.String("component", "http"),
zap.String("request_id", requestID),
zap.String("client_request_id", strings.TrimSpace(clientRequestID)),
zap.String("path", c.Request.URL.Path),
zap.String("method", c.Request.Method),
)
ctx = logger.IntoContext(ctx, requestLogger)
c.Request = c.Request.WithContext(ctx)
c.Next()
}
}

View File

@@ -3,6 +3,8 @@ package middleware
import (
"crypto/rand"
"encoding/base64"
"fmt"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -18,11 +20,14 @@ const (
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func GenerateNonce() string {
// GenerateNonce generates a cryptographically secure random nonce.
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
func GenerateNonce() (string, error) {
b := make([]byte, 16)
_, _ = rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate CSP nonce: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
// GetNonceFromContext retrieves the CSP nonce from gin context
@@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
if cfg.Enabled {
// Generate nonce for this request
nonce := GenerateNonce()
c.Set(CSPNonceKey, nonce)
// Replace nonce placeholder in policy
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
nonce, err := GenerateNonce()
if err != nil {
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
c.Header("Content-Security-Policy", finalPolicy)
} else {
c.Set(CSPNonceKey, nonce)
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
}
}
c.Next()
}

View File

@@ -19,7 +19,8 @@ func init() {
func TestGenerateNonce(t *testing.T) {
t.Run("generates_valid_base64_string", func(t *testing.T) {
nonce := GenerateNonce()
nonce, err := GenerateNonce()
require.NoError(t, err)
// Should be valid base64
decoded, err := base64.StdEncoding.DecodeString(nonce)
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
t.Run("generates_unique_nonces", func(t *testing.T) {
nonces := make(map[string]bool)
for i := 0; i < 100; i++ {
nonce := GenerateNonce()
nonce, err := GenerateNonce()
require.NoError(t, err)
assert.False(t, nonces[nonce], "nonce should be unique")
nonces[nonce] = true
}
})
t.Run("nonce_has_expected_length", func(t *testing.T) {
nonce := GenerateNonce()
nonce, err := GenerateNonce()
require.NoError(t, err)
// 16 bytes -> 24 chars in base64 (with padding)
assert.Len(t, nonce, 24)
})
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
// Benchmark tests
func BenchmarkGenerateNonce(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateNonce()
_, _ = GenerateNonce()
}
}

View File

@@ -29,6 +29,7 @@ func SetupRouter(
redisClient *redis.Client,
) *gin.Engine {
// 应用中间件
r.Use(middleware2.RequestLogger())
r.Use(middleware2.Logger())
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))

View File

@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
// Sora OAuth实现复用 OpenAI OAuth 服务,入口独立)
registerSoraOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
@@ -101,6 +103,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings)
runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings)
runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig)
runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig)
runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig)
}
// Advanced settings (DB-backed)
@@ -144,12 +149,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Request drilldown (success + error)
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
// Indexed system logs
ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs)
ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs)
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
// Dashboard (vNext - raw path for MVP)
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend)
ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution)
ops.GET("/dashboard/openai-token-stats", h.Admin.Ops.GetDashboardOpenAITokenStats)
}
}
@@ -208,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
@@ -267,6 +279,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
sora := admin.Group("/sora")
{
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
@@ -297,6 +322,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)

View File

@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 注册/登录/2FA/验证码发送均属于高风险入口增加服务端兜底限流Redis 故障时 fail-close
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.Register)
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.Login)
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.Login2FA)
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.SendVerifyCode)
// Token刷新接口添加速率限制每分钟最多 30 次Redis 故障时 fail-close
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,

View File

@@ -0,0 +1,111 @@
//go:build integration
package routes
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const authRouteRedisImageTag = "redis:8.4-alpine"
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
ctx := context.Background()
rdb := startAuthRouteRedis(t, ctx)
router := newAuthRoutesTestRouter(rdb)
const path = "/api/v1/auth/register"
for i := 1; i <= 6; i++ {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "198.51.100.10:23456"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if i <= 5 {
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
continue
}
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
require.Contains(t, w.Body.String(), "rate limit exceeded")
}
}
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
t.Helper()
ensureAuthRouteDockerAvailable(t)
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
require.NoError(t, err)
t.Cleanup(func() {
_ = redisContainer.Terminate(ctx)
})
redisHost, err := redisContainer.Host(ctx)
require.NoError(t, err)
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
require.NoError(t, err)
rdb := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
require.NoError(t, rdb.Ping(ctx).Err())
t.Cleanup(func() {
_ = rdb.Close()
})
return rdb
}
func ensureAuthRouteDockerAvailable(t *testing.T) {
t.Helper()
if authRouteDockerAvailable() {
return
}
t.Skip("Docker 未启用,跳过认证限流集成测试")
}
func authRouteDockerAvailable() bool {
if os.Getenv("DOCKER_HOST") != "" {
return true
}
socketCandidates := []string{
"/var/run/docker.sock",
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
}
for _, socket := range socketCandidates {
if socket == "" {
continue
}
if _, err := os.Stat(socket); err == nil {
return true
}
}
return false
}
func authRouteUserHomeDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return home
}

View File

@@ -0,0 +1,67 @@
package routes
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
RegisterAuthRoutes(
v1,
&handler.Handlers{
Auth: &handler.AuthHandler{},
Setting: &handler.SettingHandler{},
},
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
c.Next()
}),
redisClient,
)
return router
}
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
router := newAuthRoutesTestRouter(rdb)
paths := []string{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/auth/login/2fa",
"/api/v1/auth/send-verify-code",
}
for _, path := range paths {
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
req.Header.Set("Content-Type", "application/json")
req.RemoteAddr = "203.0.113.10:12345"
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
}
}

View File

@@ -1,6 +1,8 @@
package routes
import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -20,6 +22,11 @@ func RegisterGatewayRoutes(
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
if soraMaxBodySize <= 0 {
soraMaxBodySize = cfg.Gateway.MaxBodySize
}
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
@@ -36,6 +43,15 @@ func RegisterGatewayRoutes(
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
// 明确阻止旧协议入口OpenAI 仅支持 Responses API避免客户端误解为会自动路由到其它平台。
gateway.POST("/chat/completions", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"type": "invalid_request_error",
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
},
})
})
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
@@ -82,4 +98,25 @@ func RegisterGatewayRoutes(
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// Sora 专用路由(强制使用 sora 平台)
soraV1 := r.Group("/sora/v1")
soraV1.Use(soraBodyLimit)
soraV1.Use(clientRequestID)
soraV1.Use(opsErrorLogger)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)
}
// Sora 媒体代理(可选 API Key 验证)
if cfg.Gateway.SoraMediaRequireAPIKey {
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
} else {
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
}
// Sora 媒体代理(签名 URL无需 API Key
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
}