Merge upstream/main: v0.1.85-v0.1.86 updates
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -83,6 +83,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"status": "active",
|
||||
"ip_whitelist": null,
|
||||
"ip_blacklist": null,
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"expires_at": null,
|
||||
@@ -122,6 +123,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"status": "active",
|
||||
"ip_whitelist": null,
|
||||
"ip_blacklist": null,
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"expires_at": null,
|
||||
@@ -184,6 +186,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"image_price_1k": null,
|
||||
"image_price_2k": null,
|
||||
"image_price_4k": null,
|
||||
"sora_image_price_360": null,
|
||||
"sora_image_price_540": null,
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
@@ -401,6 +407,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"first_token_ms": 50,
|
||||
"image_count": 0,
|
||||
"image_size": null,
|
||||
"media_type": null,
|
||||
"cache_ttl_overridden": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"user_agent": null
|
||||
@@ -593,13 +600,13 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
RunMode: config.RunModeStandard,
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo, nil)
|
||||
userService := service.NewUserService(userRepo, nil, nil)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
|
||||
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg)
|
||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||
|
||||
redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil)
|
||||
@@ -608,7 +615,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
@@ -925,6 +932,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1462,6 +1473,20 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
key, ok := r.byID[id]
|
||||
if !ok {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
ts := usedAt
|
||||
key.LastUsedAt = &ts
|
||||
key.UpdatedAt = usedAt
|
||||
clone := *key
|
||||
r.byID[id] = &clone
|
||||
r.byKey[clone.Key] = &clone
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct {
|
||||
userLogs map[int64][]service.UsageLog
|
||||
}
|
||||
@@ -1607,11 +1632,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,9 @@ func ProvideRouter(
|
||||
if err := r.SetTrustedProxies(nil); err != nil {
|
||||
log.Printf("Failed to disable trusted proxies: %v", err)
|
||||
}
|
||||
if cfg.Server.Mode == "release" {
|
||||
log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
@@ -58,8 +58,13 @@ func adminAuth(
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userService) {
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
return
|
||||
}
|
||||
if !validateJWTForAdmin(c, token, authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -176,6 +181,12 @@ func validateJWTForAdmin(
|
||||
return false
|
||||
}
|
||||
|
||||
// 校验 TokenVersion,确保管理员改密后旧 token 失效
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if !user.IsAdmin() {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
|
||||
194
backend/internal/server/middleware/admin_auth_test.go
Normal file
194
backend/internal/server/middleware/admin_auth_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
Email: "admin@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 2,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
userRepo := &stubUserRepo{
|
||||
getByID: func(ctx context.Context, id int64) (*service.User, error) {
|
||||
if id != admin.ID {
|
||||
return nil, service.ErrUserNotFound
|
||||
}
|
||||
clone := *admin
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
userService := service.NewUserService(userRepo, nil, nil)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
t.Run("token_version_mismatch_rejected", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
|
||||
})
|
||||
|
||||
t.Run("token_version_match_allows", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion - 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
require.Contains(t, w.Body.String(), "TOKEN_REVOKED")
|
||||
})
|
||||
|
||||
t.Run("websocket_token_version_match_allows", func(t *testing.T) {
|
||||
token, err := authService.GenerateToken(&service.User{
|
||||
ID: admin.ID,
|
||||
Email: admin.Email,
|
||||
Role: admin.Role,
|
||||
TokenVersion: admin.TokenVersion,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
}
|
||||
|
||||
type stubUserRepo struct {
|
||||
getByID func(ctx context.Context, id int64) (*service.User, error)
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) {
|
||||
if s.getByID == nil {
|
||||
panic("GetByID not stubbed")
|
||||
}
|
||||
return s.getByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) {
|
||||
panic("unexpected GetByEmail call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected UpdateBalance call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected DeductBalance call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
panic("unexpected ExistsByEmail call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -36,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
if authHeader != "" {
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
apiKeyString = parts[1]
|
||||
if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") {
|
||||
apiKeyString = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 检查 IP 限制(白名单/黑名单)
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
|
||||
if !allowed {
|
||||
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
|
||||
@@ -126,6 +125,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:验证订阅
|
||||
// 订阅模式:获取订阅(L1 缓存 + singleflight)
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
@@ -145,30 +145,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 激活滑动窗口(首次使用时)
|
||||
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 检查并重置过期窗口
|
||||
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
// 合并验证 + 限额检查(纯内存操作)
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
}
|
||||
AbortWithError(c, status, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
@@ -185,6 +185,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -104,6 +105,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
_ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
@@ -18,7 +19,8 @@ import (
|
||||
)
|
||||
|
||||
type fakeAPIKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
@@ -78,6 +80,12 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([
|
||||
func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
if f.updateLastUsed != nil {
|
||||
return f.updateLastUsed(ctx, id, usedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
@@ -356,3 +364,144 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
require.Equal(t, "Insufficient account balance", resp.Error.Message)
|
||||
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 11,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 201,
|
||||
UserID: user.ID,
|
||||
Key: "google-touch-ok",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
var touchedID int64
|
||||
var touchedAt time.Time
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchedID = id
|
||||
touchedAt = usedAt
|
||||
return nil
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, apiKey.ID, touchedID)
|
||||
require.False(t, touchedAt.IsZero())
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 12,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 202,
|
||||
UserID: user.ID,
|
||||
Key: "google-touch-fail",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return errors.New("write failed")
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-goog-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 13,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 203,
|
||||
UserID: user.ID,
|
||||
Key: "google-touch-standard",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
r := gin.New()
|
||||
apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return nil
|
||||
},
|
||||
})
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
@@ -57,10 +57,41 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
t.Run("standard_mode_needs_maintenance_does_not_block_request", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
cfg.SubscriptionMaintenance.WorkerCount = 1
|
||||
cfg.SubscriptionMaintenance.QueueSize = 1
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
|
||||
past := time.Now().Add(-48 * time.Hour)
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
DailyWindowStart: &past,
|
||||
DailyUsageUSD: 0,
|
||||
}
|
||||
maintenanceCalled := make(chan struct{}, 1)
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error {
|
||||
maintenanceCalled <- struct{}{}
|
||||
return nil
|
||||
},
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||
t.Cleanup(subscriptionService.Stop)
|
||||
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@@ -68,6 +99,40 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
select {
|
||||
case <-maintenanceCalled:
|
||||
// ok
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("expected maintenance to be scheduled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("Authorization", "bearer "+apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
@@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
@@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
IPWhitelist: []string{"1.2.3.4"},
|
||||
}
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, w.Code)
|
||||
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "touch-ok",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
var touchedID int64
|
||||
var touchedAt time.Time
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchedID = id
|
||||
touchedAt = usedAt
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, apiKey.ID, touchedID)
|
||||
require.False(t, touchedAt.IsZero(), "expected touch timestamp")
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 8,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 101,
|
||||
UserID: user.ID,
|
||||
Key: "touch-fail",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return errors.New("db unavailable")
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "touch failure should not block request")
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchesLastUsedInStandardMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 9,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 102,
|
||||
UserID: user.ID,
|
||||
Key: "touch-standard",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
}
|
||||
|
||||
touchCalls := 0
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
touchCalls++
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := newAuthTestRouter(apiKeyService, nil, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
require.Equal(t, 1, touchCalls)
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
@@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
getByKey func(ctx context.Context, key string) (*service.APIKey, error)
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
|
||||
@@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
if r.updateLastUsed != nil {
|
||||
return r.updateLastUsed(ctx, id, usedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
|
||||
@@ -2,10 +2,13 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
|
||||
@@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id))
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)
|
||||
requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id)))
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
}
|
||||
allowedSet[origin] = struct{}{}
|
||||
}
|
||||
allowHeaders := []string{
|
||||
"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization",
|
||||
"accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key",
|
||||
}
|
||||
// OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。
|
||||
openAIProperties := []string{
|
||||
"lang", "package-version", "os", "arch", "retry-count", "runtime",
|
||||
"runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout",
|
||||
}
|
||||
for _, prop := range openAIProperties {
|
||||
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
|
||||
}
|
||||
allowHeadersValue := strings.Join(allowHeaders, ", ")
|
||||
|
||||
return func(c *gin.Context) {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
@@ -68,19 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc {
|
||||
if allowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "86400")
|
||||
}
|
||||
|
||||
allowHeaders := []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key"}
|
||||
|
||||
// openai node sdk
|
||||
openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"}
|
||||
for _, prop := range openAIProperties {
|
||||
allowHeaders = append(allowHeaders, "x-stainless-"+prop)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ", "))
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if originAllowed {
|
||||
|
||||
308
backend/internal/server/middleware/cors_test.go
Normal file
308
backend/internal/server/middleware/cors_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// cors_test 与 security_headers_test 在同一个包,但 init 是幂等的
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// --- Task 8.2: 验证 CORS 条件化头部 ---
|
||||
|
||||
func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
origin string
|
||||
}{
|
||||
{
|
||||
name: "preflight_disallowed_origin",
|
||||
method: http.MethodOptions,
|
||||
origin: "https://evil.example.com",
|
||||
},
|
||||
{
|
||||
name: "get_disallowed_origin",
|
||||
method: http.MethodGet,
|
||||
origin: "https://evil.example.com",
|
||||
},
|
||||
{
|
||||
name: "post_disallowed_origin",
|
||||
method: http.MethodPost,
|
||||
origin: "https://attacker.example.com",
|
||||
},
|
||||
{
|
||||
name: "preflight_no_origin",
|
||||
method: http.MethodOptions,
|
||||
origin: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||
if tt.origin != "" {
|
||||
c.Request.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 不应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"不允许的 origin 不应收到 Allow-Headers")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"不允许的 origin 不应收到 Allow-Methods")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Max-Age"),
|
||||
"不允许的 origin 不应收到 Max-Age")
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"不允许的 origin 不应收到 Allow-Origin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
}{
|
||||
{name: "preflight_OPTIONS", method: http.MethodOptions},
|
||||
{name: "normal_GET", method: http.MethodGet},
|
||||
{name: "normal_POST", method: http.MethodPost},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(tt.method, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 应设置 Allow-Headers、Allow-Methods 和 Max-Age
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"允许的 origin 应收到 Allow-Headers")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"允许的 origin 应收到 Allow-Methods")
|
||||
assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"),
|
||||
"允许的 origin 应收到 Max-Age=86400")
|
||||
assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"允许的 origin 应收到 Allow-Origin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, http.StatusForbidden, w.Code,
|
||||
"不允许的 origin 的 preflight 请求应返回 403")
|
||||
}
|
||||
|
||||
func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code,
|
||||
"允许的 origin 的 preflight 请求应返回 204")
|
||||
}
|
||||
|
||||
func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://any-origin.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"),
|
||||
"通配符配置应返回 *")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"),
|
||||
"通配符 origin 应设置 Allow-Headers")
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"),
|
||||
"通配符 origin 应设置 Allow-Methods")
|
||||
}
|
||||
|
||||
func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
t.Run("allowed_origin_gets_credentials", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"允许的 origin 且开启 credentials 应设置 Allow-Credentials")
|
||||
})
|
||||
|
||||
t.Run("disallowed_origin_no_credentials", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"不允许的 origin 不应收到 Allow-Credentials")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://any.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
// 通配符 + credentials 不兼容,credentials 应被禁用
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"),
|
||||
"通配符 origin 应禁用 Allow-Credentials")
|
||||
}
|
||||
|
||||
func TestCORS_MultipleAllowedOrigins(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{
|
||||
"https://app1.example.com",
|
||||
"https://app2.example.com",
|
||||
},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
t.Run("first_origin_allowed", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app1.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("second_origin_allowed", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app2.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("unlisted_origin_rejected", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://app3.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://allowed.example.com"},
|
||||
AllowCredentials: false,
|
||||
}
|
||||
middleware := CORS(cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
c.Request.Header.Set("Origin", "https://allowed.example.com")
|
||||
|
||||
middleware(c)
|
||||
|
||||
assert.Contains(t, w.Header().Values("Vary"), "Origin",
|
||||
"非通配符允许的 origin 应设置 Vary: Origin")
|
||||
}
|
||||
|
||||
func TestNormalizeOrigins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
expect []string
|
||||
}{
|
||||
{name: "nil_input", input: nil, expect: nil},
|
||||
{name: "empty_input", input: []string{}, expect: nil},
|
||||
{name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}},
|
||||
{name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeOrigins(tt.input)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService)
|
||||
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
tokenString := strings.TrimSpace(parts[1])
|
||||
if tokenString == "" {
|
||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||
return
|
||||
|
||||
256
backend/internal/server/middleware/jwt_auth_test.go
Normal file
256
backend/internal/server/middleware/jwt_auth_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。
|
||||
type stubJWTUserRepo struct {
|
||||
service.UserRepository
|
||||
users map[int64]*service.User
|
||||
}
|
||||
|
||||
func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) {
|
||||
u, ok := r.users[id]
|
||||
if !ok {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// newJWTTestEnv 创建 JWT 认证中间件测试环境。
|
||||
// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。
|
||||
func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!"
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.HandlerFunc(mw))
|
||||
r.GET("/protected", func(c *gin.Context) {
|
||||
subject, _ := GetAuthSubjectFromContext(c)
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": subject.UserID,
|
||||
"role": role,
|
||||
})
|
||||
})
|
||||
return r, authSvc
|
||||
}
|
||||
|
||||
func TestJWTAuth_ValidToken(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var body map[string]any
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, float64(1), body["user_id"])
|
||||
require.Equal(t, "user", body["role"])
|
||||
}
|
||||
|
||||
func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "UNAUTHORIZED", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_InvalidHeaderFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"无Bearer前缀", "Token abc123"},
|
||||
{"缺少空格分隔", "Bearerabc123"},
|
||||
{"仅有单词", "abc123"},
|
||||
}
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", tt.header)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "INVALID_AUTH_HEADER", body.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTAuth_EmptyToken(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer ")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "EMPTY_TOKEN", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_TamperedToken(t *testing.T) {
|
||||
router, _ := newJWTTestEnv(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "INVALID_TOKEN", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_UserNotFound(t *testing.T) {
|
||||
// 使用 user ID=1 的 token,但 repo 中没有该用户
|
||||
fakeUser := &service.User{
|
||||
ID: 999,
|
||||
Email: "ghost@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
// 创建环境时不注入此用户,这样 GetByID 会失败
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{})
|
||||
|
||||
token, err := authSvc.GenerateToken(fakeUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "USER_NOT_FOUND", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_UserInactive(t *testing.T) {
|
||||
user := &service.User{
|
||||
ID: 1,
|
||||
Email: "disabled@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusDisabled,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user})
|
||||
|
||||
token, err := authSvc.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "USER_INACTIVE", body.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth_TokenVersionMismatch(t *testing.T) {
|
||||
// Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改)
|
||||
userForToken := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
userInDB := &service.User{
|
||||
ID: 1,
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
Status: service.StatusActive,
|
||||
TokenVersion: 2, // 密码修改后版本递增
|
||||
}
|
||||
router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB})
|
||||
|
||||
token, err := authSvc.GenerateToken(userForToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
||||
var body ErrorResponse
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.Equal(t, "TOKEN_REVOKED", body.Code)
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
@@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc {
|
||||
// 开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 结束时间
|
||||
endTime := time.Now()
|
||||
|
||||
// 执行时间
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
// 请求方法
|
||||
method := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 状态码
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 跳过健康检查等高频探针路径的日志
|
||||
if path == "/health" || path == "/setup/status" {
|
||||
return
|
||||
}
|
||||
|
||||
endTime := time.Now()
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 协议版本
|
||||
protocol := c.Request.Proto
|
||||
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
|
||||
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)
|
||||
model, _ := c.Request.Context().Value(ctxkey.Model).(string)
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
|
||||
endTime.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
protocol,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "http.access"),
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.Int64("latency_ms", latency.Milliseconds()),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("protocol", protocol),
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
}
|
||||
if hasAccountID && accountID > 0 {
|
||||
fields = append(fields, zap.Int64("account_id", accountID))
|
||||
}
|
||||
if platform != "" {
|
||||
fields = append(fields, zap.String("platform", platform))
|
||||
}
|
||||
if model != "" {
|
||||
fields = append(fields, zap.String("model", model))
|
||||
}
|
||||
|
||||
l := logger.FromContext(c.Request.Context()).With(fields...)
|
||||
l.Info("http request completed", zap.Time("completed_at", endTime))
|
||||
|
||||
// 如果有错误,额外记录错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
log.Printf("[GIN] Errors: %v", c.Errors.String())
|
||||
l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
126
backend/internal/server/middleware/misc_coverage_test.go
Normal file
126
backend/internal/server/middleware/misc_coverage_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientRequestID_GeneratesWhenMissing(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ClientRequestID())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
v := c.Request.Context().Value(ctxkey.ClientRequestID)
|
||||
require.NotNil(t, v)
|
||||
id, ok := v.(string)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, id)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestClientRequestID_PreservesExisting(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ClientRequestID())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "keep", id)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep"))
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestRequestBodyLimit_LimitsBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(RequestBodyLimit(4))
|
||||
r.POST("/t", func(c *gin.Context) {
|
||||
_, err := io.ReadAll(c.Request.Body)
|
||||
require.Error(t, err)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345"))
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestForcePlatform_SetsContextAndGinValue(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(ForcePlatform("anthropic"))
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
require.True(t, HasForcePlatform(c))
|
||||
v, ok := GetForcePlatformFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "anthropic", v)
|
||||
|
||||
ctxV := c.Request.Context().Value(ctxkey.ForcePlatform)
|
||||
require.Equal(t, "anthropic", ctxV)
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAuthSubjectHelpers_RoundTrip(t *testing.T) {
|
||||
c := &gin.Context{}
|
||||
c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2})
|
||||
c.Set(string(ContextKeyUserRole), "admin")
|
||||
|
||||
sub, ok := GetAuthSubjectFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(1), sub.UserID)
|
||||
require.Equal(t, 2, sub.Concurrency)
|
||||
|
||||
role, ok := GetUserRoleFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "admin", role)
|
||||
}
|
||||
|
||||
func TestAPIKeyAndSubscriptionFromContext(t *testing.T) {
|
||||
c := &gin.Context{}
|
||||
|
||||
key := &service.APIKey{ID: 1}
|
||||
c.Set(string(ContextKeyAPIKey), key)
|
||||
gotKey, ok := GetAPIKeyFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(1), gotKey.ID)
|
||||
|
||||
sub := &service.UserSubscription{ID: 2}
|
||||
c.Set(string(ContextKeySubscription), sub)
|
||||
gotSub, ok := GetSubscriptionFromContext(c)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, int64(2), gotSub.ID)
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -14,6 +15,34 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecovery_PanicLogContainsInfo(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 临时替换 DefaultErrorWriter 以捕获日志输出
|
||||
var buf bytes.Buffer
|
||||
originalWriter := gin.DefaultErrorWriter
|
||||
gin.DefaultErrorWriter = &buf
|
||||
t.Cleanup(func() {
|
||||
gin.DefaultErrorWriter = originalWriter
|
||||
})
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Recovery())
|
||||
r.GET("/panic", func(c *gin.Context) {
|
||||
panic("custom panic message for test")
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
||||
|
||||
logOutput := buf.String()
|
||||
require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息")
|
||||
require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名")
|
||||
}
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
228
backend/internal/server/middleware/request_access_logger_test.go
Normal file
228
backend/internal/server/middleware/request_access_logger_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type testLogSink struct {
|
||||
mu sync.Mutex
|
||||
events []*logger.LogEvent
|
||||
}
|
||||
|
||||
func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.events = append(s.events, event)
|
||||
}
|
||||
|
||||
func (s *testLogSink) list() []*logger.LogEvent {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]*logger.LogEvent, len(s.events))
|
||||
copy(out, s.events)
|
||||
return out
|
||||
}
|
||||
|
||||
func initMiddlewareTestLogger(t *testing.T) *testLogSink {
|
||||
return initMiddlewareTestLoggerWithLevel(t, "debug")
|
||||
}
|
||||
|
||||
func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink {
|
||||
t.Helper()
|
||||
level = strings.TrimSpace(level)
|
||||
if level == "" {
|
||||
level = "debug"
|
||||
}
|
||||
if err := logger.Init(logger.InitOptions{
|
||||
Level: level,
|
||||
Format: "json",
|
||||
ServiceName: "sub2api",
|
||||
Environment: "test",
|
||||
Output: logger.OutputOptions{
|
||||
ToStdout: false,
|
||||
ToFile: false,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("init logger: %v", err)
|
||||
}
|
||||
sink := &testLogSink{}
|
||||
logger.SetSink(sink)
|
||||
t.Cleanup(func() {
|
||||
logger.SetSink(nil)
|
||||
})
|
||||
return sink
|
||||
}
|
||||
|
||||
func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string)
|
||||
if !ok || reqID == "" {
|
||||
t.Fatalf("request_id missing in context")
|
||||
}
|
||||
if got := c.Writer.Header().Get(requestIDHeader); got != reqID {
|
||||
t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if w.Header().Get(requestIDHeader) == "" {
|
||||
t.Fatalf("X-Request-ID should be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestLogger_KeepIncomingRequestID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.GET("/t", func(c *gin.Context) {
|
||||
reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string)
|
||||
if reqID != "rid-fixed" {
|
||||
t.Fatalf("request_id=%q, want rid-fixed", reqID)
|
||||
}
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set(requestIDHeader, "rid-fixed")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if got := w.Header().Get(requestIDHeader); got != "rid-fixed" {
|
||||
t.Fatalf("header=%q, want rid-fixed", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.Use(func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101))
|
||||
ctx = context.WithValue(ctx, ctxkey.Platform, "openai")
|
||||
ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5")
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
})
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
events := sink.list()
|
||||
if len(events) == 0 {
|
||||
t.Fatalf("expected at least one log event")
|
||||
}
|
||||
found := false
|
||||
for _, event := range events {
|
||||
if event == nil || event.Message != "http request completed" {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
switch v := event.Fields["status_code"].(type) {
|
||||
case int:
|
||||
if v != http.StatusCreated {
|
||||
t.Fatalf("status_code field mismatch: %v", v)
|
||||
}
|
||||
case int64:
|
||||
if v != int64(http.StatusCreated) {
|
||||
t.Fatalf("status_code field mismatch: %v", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("status_code type mismatch: %T", v)
|
||||
}
|
||||
switch v := event.Fields["account_id"].(type) {
|
||||
case int64:
|
||||
if v != 101 {
|
||||
t.Fatalf("account_id field mismatch: %v", v)
|
||||
}
|
||||
case int:
|
||||
if v != 101 {
|
||||
t.Fatalf("account_id field mismatch: %v", v)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("account_id type mismatch: %T", v)
|
||||
}
|
||||
if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" {
|
||||
t.Fatalf("platform/model mismatch: %+v", event.Fields)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("access log event not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_HealthPathSkipped(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
if len(sink.list()) != 0 {
|
||||
t.Fatalf("health endpoint should not write access log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogDroppedWhenLevelWarn(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLoggerWithLevel(t, "warn")
|
||||
|
||||
r := gin.New()
|
||||
r.Use(RequestLogger())
|
||||
r.Use(Logger())
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
events := sink.list()
|
||||
for _, event := range events {
|
||||
if event != nil && event.Message == "http request completed" {
|
||||
t.Fatalf("access log should not be indexed when level=warn: %+v", event)
|
||||
}
|
||||
}
|
||||
}
|
||||
45
backend/internal/server/middleware/request_logger.go
Normal file
45
backend/internal/server/middleware/request_logger.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const requestIDHeader = "X-Request-ID"
|
||||
|
||||
// RequestLogger 在请求入口注入 request-scoped logger。
|
||||
func RequestLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request == nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.GetHeader(requestIDHeader))
|
||||
if requestID == "" {
|
||||
requestID = uuid.NewString()
|
||||
}
|
||||
c.Header(requestIDHeader, requestID)
|
||||
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID)
|
||||
clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string)
|
||||
|
||||
requestLogger := logger.With(
|
||||
zap.String("component", "http"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("client_request_id", strings.TrimSpace(clientRequestID)),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
)
|
||||
|
||||
ctx = logger.IntoContext(ctx, requestLogger)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package middleware
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -18,11 +20,14 @@ const (
|
||||
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
||||
)
|
||||
|
||||
// GenerateNonce generates a cryptographically secure random nonce
|
||||
func GenerateNonce() string {
|
||||
// GenerateNonce generates a cryptographically secure random nonce.
|
||||
// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。
|
||||
func GenerateNonce() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate CSP nonce: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// GetNonceFromContext retrieves the CSP nonce from gin context
|
||||
@@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
|
||||
if cfg.Enabled {
|
||||
// Generate nonce for this request
|
||||
nonce := GenerateNonce()
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
|
||||
// Replace nonce placeholder in policy
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
nonce, err := GenerateNonce()
|
||||
if err != nil {
|
||||
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
|
||||
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
} else {
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -19,7 +19,8 @@ func init() {
|
||||
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
t.Run("generates_valid_base64_string", func(t *testing.T) {
|
||||
nonce := GenerateNonce()
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be valid base64
|
||||
decoded, err := base64.StdEncoding.DecodeString(nonce)
|
||||
@@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) {
|
||||
t.Run("generates_unique_nonces", func(t *testing.T) {
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
nonce := GenerateNonce()
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, nonces[nonce], "nonce should be unique")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonce_has_expected_length", func(t *testing.T) {
|
||||
nonce := GenerateNonce()
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
// 16 bytes -> 24 chars in base64 (with padding)
|
||||
assert.Len(t, nonce, 24)
|
||||
})
|
||||
@@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) {
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateNonce(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenerateNonce()
|
||||
_, _ = GenerateNonce()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ func SetupRouter(
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS(cfg.CORS))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
|
||||
|
||||
@@ -34,6 +34,8 @@ func RegisterAdminRoutes(
|
||||
|
||||
// OpenAI OAuth
|
||||
registerOpenAIOAuthRoutes(admin, h)
|
||||
// Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立)
|
||||
registerSoraOAuthRoutes(admin, h)
|
||||
|
||||
// Gemini OAuth
|
||||
registerGeminiOAuthRoutes(admin, h)
|
||||
@@ -101,6 +103,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
{
|
||||
runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings)
|
||||
runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings)
|
||||
runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig)
|
||||
runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig)
|
||||
runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig)
|
||||
}
|
||||
|
||||
// Advanced settings (DB-backed)
|
||||
@@ -144,12 +149,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// Request drilldown (success + error)
|
||||
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
|
||||
|
||||
// Indexed system logs
|
||||
ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs)
|
||||
ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs)
|
||||
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
||||
|
||||
// Dashboard (vNext - raw path for MVP)
|
||||
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
||||
ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend)
|
||||
ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution)
|
||||
ops.GET("/dashboard/openai-token-stats", h.Admin.Ops.GetDashboardOpenAITokenStats)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,6 +219,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("", h.Admin.Account.List)
|
||||
accounts.GET("/:id", h.Admin.Account.GetByID)
|
||||
accounts.POST("", h.Admin.Account.Create)
|
||||
accounts.POST("/check-mixed-channel", h.Admin.Account.CheckMixedChannel)
|
||||
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
|
||||
accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS)
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
@@ -267,6 +279,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
sora := admin.Group("/sora")
|
||||
{
|
||||
sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
|
||||
sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
|
||||
sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken)
|
||||
sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken)
|
||||
sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
|
||||
sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
gemini := admin.Group("/gemini")
|
||||
{
|
||||
@@ -297,6 +322,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
proxies.PUT("/:id", h.Admin.Proxy.Update)
|
||||
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
|
||||
proxies.POST("/:id/test", h.Admin.Proxy.Test)
|
||||
proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality)
|
||||
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
|
||||
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
|
||||
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
|
||||
|
||||
@@ -24,10 +24,19 @@ func RegisterAuthRoutes(
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/login/2fa", h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Register)
|
||||
auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login)
|
||||
auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.Login2FA)
|
||||
auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}), h.Auth.SendVerifyCode)
|
||||
// Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close)
|
||||
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build integration
|
||||
|
||||
package routes
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const authRouteRedisImageTag = "redis:8.4-alpine"
|
||||
|
||||
func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := startAuthRouteRedis(t, ctx)
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
const path = "/api/v1/auth/register"
|
||||
|
||||
for i := 1; i <= 6; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "198.51.100.10:23456"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if i <= 5 {
|
||||
require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i)
|
||||
continue
|
||||
}
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流")
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client {
|
||||
t.Helper()
|
||||
ensureAuthRouteDockerAvailable(t)
|
||||
|
||||
redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = redisContainer.Terminate(ctx)
|
||||
})
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
require.NoError(t, err)
|
||||
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
require.NoError(t, rdb.Ping(ctx).Err())
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
return rdb
|
||||
}
|
||||
|
||||
func ensureAuthRouteDockerAvailable(t *testing.T) {
|
||||
t.Helper()
|
||||
if authRouteDockerAvailable() {
|
||||
return
|
||||
}
|
||||
t.Skip("Docker 未启用,跳过认证限流集成测试")
|
||||
}
|
||||
|
||||
func authRouteDockerAvailable() bool {
|
||||
if os.Getenv("DOCKER_HOST") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
socketCandidates := []string{
|
||||
"/var/run/docker.sock",
|
||||
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"),
|
||||
filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"),
|
||||
filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"),
|
||||
}
|
||||
|
||||
for _, socket := range socketCandidates {
|
||||
if socket == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socket); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func authRouteUserHomeDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return home
|
||||
}
|
||||
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
67
backend/internal/server/routes/auth_rate_limit_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
RegisterAuthRoutes(
|
||||
v1,
|
||||
&handler.Handlers{
|
||||
Auth: &handler.AuthHandler{},
|
||||
Setting: &handler.SettingHandler{},
|
||||
},
|
||||
servermiddleware.JWTAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
router := newAuthRoutesTestRouter(rdb)
|
||||
paths := []string{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/login/2fa",
|
||||
"/api/v1/auth/send-verify-code",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "203.0.113.10:12345"
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path)
|
||||
require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -20,6 +22,11 @@ func RegisterGatewayRoutes(
|
||||
cfg *config.Config,
|
||||
) {
|
||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||
soraMaxBodySize := cfg.Gateway.SoraMaxBodySize
|
||||
if soraMaxBodySize <= 0 {
|
||||
soraMaxBodySize = cfg.Gateway.MaxBodySize
|
||||
}
|
||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
|
||||
@@ -36,6 +43,15 @@ func RegisterGatewayRoutes(
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
@@ -82,4 +98,25 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
|
||||
}
|
||||
|
||||
// Sora 专用路由(强制使用 sora 平台)
|
||||
soraV1 := r.Group("/sora/v1")
|
||||
soraV1.Use(soraBodyLimit)
|
||||
soraV1.Use(clientRequestID)
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
{
|
||||
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
soraV1.GET("/models", h.Gateway.Models)
|
||||
}
|
||||
|
||||
// Sora 媒体代理(可选 API Key 验证)
|
||||
if cfg.Gateway.SoraMediaRequireAPIKey {
|
||||
r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy)
|
||||
} else {
|
||||
r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy)
|
||||
}
|
||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user