Merge upstream/main

This commit is contained in:
song
2026-01-17 18:00:07 +08:00
394 changed files with 76872 additions and 1877 deletions

View File

@@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
@@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) {
"name": "Key One",
"group_id": null,
"status": "active",
"ip_whitelist": null,
"ip_blacklist": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
}
@@ -235,9 +239,10 @@ func TestAPIContracts(t *testing.T) {
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
"actual_cost": 0.5,
"rate_multiplier": 1,
"billing_type": 0,
"actual_cost": 0.5,
"rate_multiplier": 1,
"account_rate_multiplier": null,
"billing_type": 0,
"stream": true,
"duration_ms": 100,
"first_token_ms": 50,
@@ -283,6 +288,11 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyDefaultConcurrency: "5",
service.SettingKeyDefaultBalance: "1.25",
service.SettingKeyOpsMonitoringEnabled: "false",
service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
service.SettingKeyOpsQueryModeDefault: "auto",
service.SettingKeyOpsMetricsIntervalSeconds: "60",
})
},
method: http.MethodGet,
@@ -305,13 +315,17 @@ func TestAPIContracts(t *testing.T) {
"turnstile_site_key": "site-key",
"turnstile_secret_key_configured": true,
"linuxdo_connect_enabled": false,
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
"api_base_url": "https://api.example.com",
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
"ops_monitoring_enabled": false,
"ops_realtime_monitoring_enabled": true,
"ops_query_mode_default": "auto",
"ops_metrics_interval_seconds": 60,
"site_name": "Sub2API",
"site_logo": "",
"site_subtitle": "Subtitle",
"api_base_url": "https://api.example.com",
"contact_info": "support",
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
@@ -322,7 +336,32 @@ func TestAPIContracts(t *testing.T) {
"fallback_model_gemini": "gemini-2.5-pro",
"fallback_model_openai": "gpt-4o",
"enable_identity_patch": true,
"identity_patch_prompt": ""
"identity_patch_prompt": "",
"home_content": ""
}
}`,
},
{
name: "POST /api/v1/admin/accounts/bulk-update",
method: http.MethodPost,
path: "/api/v1/admin/accounts/bulk-update",
body: `{"account_ids":[101,102],"schedulable":false}`,
headers: map[string]string{
"Content-Type": "application/json",
},
wantStatus: http.StatusOK,
wantJSON: `{
"code": 0,
"message": "success",
"data": {
"success": 2,
"failed": 0,
"success_ids": [101, 102],
"failed_ids": [],
"results": [
{"account_id": 101, "success": true},
{"account_id": 102, "success": true}
]
}
}`,
},
@@ -377,6 +416,9 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyCache := stubApiKeyCache{}
groupRepo := stubGroupRepo{}
userSubRepo := stubUserSubscriptionRepo{}
accountRepo := stubAccountRepo{}
proxyRepo := stubProxyRepo{}
redeemRepo := stubRedeemCodeRepo{}
cfg := &config.Config{
Default: config.DefaultConfig{
@@ -385,19 +427,21 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard,
}
userService := service.NewUserService(userRepo)
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil)
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -437,6 +481,7 @@ func newContractDeps(t *testing.T) *contractDeps {
v1Admin := v1.Group("/admin")
v1Admin.Use(adminAuth)
v1Admin.GET("/settings", adminSettingHandler.GetSettings)
v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate)
return &contractDeps{
now: now,
@@ -561,6 +606,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t
return nil
}
func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return nil
}
type stubGroupRepo struct{}
func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error {
@@ -571,6 +628,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
}
@@ -611,6 +672,251 @@ func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i
return 0, errors.New("not implemented")
}
type stubAccountRepo struct {
bulkUpdateIDs []int64
}
func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) {
return nil, service.ErrAccountNotFound
}
func (s *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) {
return false, errors.New("not implemented")
}
func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, errors.New("not implemented")
}
func (s *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) {
return nil, errors.New("not implemented")
}
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
return errors.New("not implemented")
}
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
return int64(len(ids)), nil
}
type stubProxyRepo struct{}
func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error {
return errors.New("not implemented")
}
func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
return nil, service.ErrProxyNotFound
}
func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error {
return errors.New("not implemented")
}
func (stubProxyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubProxyRepo) ListActive(ctx context.Context) ([]service.Proxy, error) {
return nil, errors.New("not implemented")
}
func (stubProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
return nil, errors.New("not implemented")
}
func (stubProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
return false, errors.New("not implemented")
}
func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
return nil, errors.New("not implemented")
}
type stubRedeemCodeRepo struct{}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
return errors.New("not implemented")
}
func (stubRedeemCodeRepo) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
return errors.New("not implemented")
}
func (stubRedeemCodeRepo) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
return nil, service.ErrRedeemCodeNotFound
}
func (stubRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
return nil, service.ErrRedeemCodeNotFound
}
func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error {
return errors.New("not implemented")
}
func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (stubRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error {
return errors.New("not implemented")
}
func (stubRedeemCodeRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
return nil, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct{}
func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
@@ -729,12 +1035,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return &clone, nil
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
key, ok := r.byID[id]
if !ok {
return 0, service.ErrAPIKeyNotFound
return "", 0, service.ErrAPIKeyNotFound
}
return key.UserID, nil
return key.Key, key.UserID, nil
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -746,6 +1052,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return &clone, nil
}
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return r.GetByKey(ctx, key)
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
if key == nil {
return errors.New("nil key")
@@ -860,6 +1170,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
@@ -928,11 +1246,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
"github.com/redis/go-redis/v9"
)
// ProviderSet 提供服务器层的依赖
@@ -30,6 +31,9 @@ func ProvideRouter(
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
settingService *service.SettingService,
redisClient *redis.Client,
) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
@@ -47,7 +51,7 @@ func ProvideRouter(
}
}
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
}
// ProvideHTTPServer 提供 HTTP 服务器

View File

@@ -30,6 +30,20 @@ func adminAuth(
settingService *service.SettingService,
) gin.HandlerFunc {
return func(c *gin.Context) {
// WebSocket upgrade requests cannot set Authorization headers in browsers.
// For admin WebSocket endpoints (e.g. Ops realtime), allow passing the JWT via
// Sec-WebSocket-Protocol (subprotocol list) using a prefixed token item:
// Sec-WebSocket-Protocol: sub2api-admin, jwt.<token>
if isWebSocketUpgradeRequest(c) {
if token := extractJWTFromWebSocketSubprotocol(c); token != "" {
if !validateJWTForAdmin(c, token, authService, userService) {
return
}
c.Next()
return
}
}
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
@@ -58,6 +72,44 @@ func adminAuth(
}
}
func isWebSocketUpgradeRequest(c *gin.Context) bool {
if c == nil || c.Request == nil {
return false
}
// RFC6455 handshake uses:
// Connection: Upgrade
// Upgrade: websocket
upgrade := strings.ToLower(strings.TrimSpace(c.GetHeader("Upgrade")))
if upgrade != "websocket" {
return false
}
connection := strings.ToLower(c.GetHeader("Connection"))
return strings.Contains(connection, "upgrade")
}
func extractJWTFromWebSocketSubprotocol(c *gin.Context) string {
if c == nil {
return ""
}
raw := strings.TrimSpace(c.GetHeader("Sec-WebSocket-Protocol"))
if raw == "" {
return ""
}
// The header is a comma-separated list of tokens. We reserve the prefix "jwt."
// for carrying the admin JWT.
for _, part := range strings.Split(raw, ",") {
p := strings.TrimSpace(part)
if strings.HasPrefix(p, "jwt.") {
token := strings.TrimSpace(strings.TrimPrefix(p, "jwt."))
if token != "" {
return token
}
}
}
return ""
}
// validateAdminAPIKey 验证管理员 API Key
func validateAdminAPIKey(
c *gin.Context,

View File

@@ -1,11 +1,14 @@
package middleware
import (
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -71,6 +74,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
@@ -91,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
return
}
@@ -149,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
}
@@ -173,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription, ok := value.(*service.UserSubscription)
return subscription, ok
}
func setGroupContext(c *gin.Context, group *service.Group) {
if !service.IsGroupContextValid(group) {
return
}
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
return
}
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
c.Request = c.Request.WithContext(ctx)
}

View File

@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
return
}
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
}
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -26,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error {
func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
return "", 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
if f.getByKey == nil {
@@ -35,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK
}
return f.getByKey(ctx, key)
}
func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return f.GetByKey(ctx, key)
}
func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -65,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64
func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {
@@ -133,6 +143,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 99,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := service.NewAPIKeyService(
fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
},
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)
cfg := &config.Config{RunMode: config.RunModeSimple}
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID: 42,
Name: "sub",
Status: service.StatusActive,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
invalidGroup := &service.Group{
ID: group.ID,
Platform: group.Platform,
Status: group.Status,
}
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
@@ -131,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
return "", 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -142,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
return r.GetByKey(ctx, key)
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
return errors.New("not implemented")
}
@@ -182,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
return nil, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error

View File

@@ -0,0 +1,30 @@
package middleware
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
// ClientRequestID ensures every request has a unique client_request_id in request.Context().
//
// This is used by the Ops monitoring module for end-to-end request correlation.
func ClientRequestID() gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request == nil {
c.Next()
return
}
if v := c.Request.Context().Value(ctxkey.ClientRequestID); v != nil {
c.Next()
return
}
id := uuid.New().String()
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id))
c.Next()
}
}

View File

@@ -1,12 +1,40 @@
package middleware
import (
"crypto/rand"
"encoding/base64"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
const (
// CSPNonceKey is the context key for storing the CSP nonce
CSPNonceKey = "csp_nonce"
// NonceTemplate is the placeholder in CSP policy for nonce
NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
)
// GenerateNonce generates a cryptographically secure random nonce
func GenerateNonce() string {
b := make([]byte, 16)
_, _ = rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
// GetNonceFromContext retrieves the CSP nonce from gin context
func GetNonceFromContext(c *gin.Context) string {
if nonce, exists := c.Get(CSPNonceKey); exists {
if s, ok := nonce.(string); ok {
return s
}
}
return ""
}
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy = config.DefaultCSPPolicy
}
// Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
policy = enhanceCSPPolicy(policy)
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
// Generate nonce for this request
nonce := GenerateNonce()
c.Set(CSPNonceKey, nonce)
// Replace nonce placeholder in policy
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
c.Header("Content-Security-Policy", finalPolicy)
}
c.Next()
}
}
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
// This allows the application to work correctly even if the config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
policy = addToDirective(policy, "script-src", NonceTemplate)
}
// Add Cloudflare Insights domain to script-src if not present
if !strings.Contains(policy, CloudflareInsightsDomain) {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
}
return policy
}
// addToDirective adds a value to a specific CSP directive.
// If the directive doesn't exist, it will be added after default-src.
func addToDirective(policy, directive, value string) string {
// Find the directive in the policy
directivePrefix := directive + " "
idx := strings.Index(policy, directivePrefix)
if idx == -1 {
// Directive not found, add it after default-src or at the beginning
defaultSrcIdx := strings.Index(policy, "default-src ")
if defaultSrcIdx != -1 {
// Find the end of default-src directive (next semicolon)
endIdx := strings.Index(policy[defaultSrcIdx:], ";")
if endIdx != -1 {
insertPos := defaultSrcIdx + endIdx + 1
// Insert new directive after default-src
return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
}
}
// Fallback: prepend the directive
return directive + " 'self' " + value + "; " + policy
}
// Find the end of this directive (next semicolon or end of string)
endIdx := strings.Index(policy[idx:], ";")
if endIdx == -1 {
// No semicolon found, directive goes to end of string
return policy + " " + value
}
// Insert value before the semicolon
insertPos := idx + endIdx
return policy[:insertPos] + " " + value + policy[insertPos:]
}

View File

@@ -0,0 +1,365 @@
package middleware
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
gin.SetMode(gin.TestMode)
}
func TestGenerateNonce(t *testing.T) {
t.Run("generates_valid_base64_string", func(t *testing.T) {
nonce := GenerateNonce()
// Should be valid base64
decoded, err := base64.StdEncoding.DecodeString(nonce)
require.NoError(t, err)
// Should decode to 16 bytes
assert.Len(t, decoded, 16)
})
t.Run("generates_unique_nonces", func(t *testing.T) {
nonces := make(map[string]bool)
for i := 0; i < 100; i++ {
nonce := GenerateNonce()
assert.False(t, nonces[nonce], "nonce should be unique")
nonces[nonce] = true
}
})
t.Run("nonce_has_expected_length", func(t *testing.T) {
nonce := GenerateNonce()
// 16 bytes -> 24 chars in base64 (with padding)
assert.Len(t, nonce, 24)
})
}
func TestGetNonceFromContext(t *testing.T) {
t.Run("returns_nonce_when_present", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
expectedNonce := "test-nonce-123"
c.Set(CSPNonceKey, expectedNonce)
nonce := GetNonceFromContext(c)
assert.Equal(t, expectedNonce, nonce)
})
t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
nonce := GetNonceFromContext(c)
assert.Empty(t, nonce)
})
t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Set a non-string value
c.Set(CSPNonceKey, 12345)
// Should return empty string for wrong type (safe type assertion)
nonce := GetNonceFromContext(c)
assert.Empty(t, nonce)
})
}
func TestSecurityHeaders(t *testing.T) {
t.Run("sets_basic_security_headers", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
})
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: false}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
assert.Empty(t, w.Header().Get("Content-Security-Policy"))
})
t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "default-src 'self'",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
// Policy is auto-enhanced with nonce and Cloudflare Insights domain
assert.Contains(t, csp, "default-src 'self'")
assert.Contains(t, csp, "'nonce-")
assert.Contains(t, csp, CloudflareInsightsDomain)
})
t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
// Verify nonce is stored in context
nonce := GetNonceFromContext(c)
assert.NotEmpty(t, nonce)
assert.Contains(t, csp, "'nonce-"+nonce+"'")
})
t.Run("uses_default_policy_when_empty", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
// Default policy should contain these elements
assert.Contains(t, csp, "default-src 'self'")
})
t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: " \t\n ",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
assert.NotEmpty(t, csp)
assert.Contains(t, csp, "default-src 'self'")
})
t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
csp := w.Header().Get("Content-Security-Policy")
nonce := GetNonceFromContext(c)
// Count occurrences of the nonce
count := strings.Count(csp, "'nonce-"+nonce+"'")
assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
})
t.Run("calls_next_handler", func(t *testing.T) {
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
middleware := SecurityHeaders(cfg)
nextCalled := false
router := gin.New()
router.Use(middleware)
router.GET("/test", func(c *gin.Context) {
nextCalled = true
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
router.ServeHTTP(w, req)
assert.True(t, nextCalled, "next handler should be called")
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("nonce_unique_per_request", func(t *testing.T) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
nonces := make(map[string]bool)
for i := 0; i < 10; i++ {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
nonce := GetNonceFromContext(c)
assert.False(t, nonces[nonce], "nonce should be unique per request")
nonces[nonce] = true
}
})
}
func TestCSPNonceKey(t *testing.T) {
t.Run("constant_value", func(t *testing.T) {
assert.Equal(t, "csp_nonce", CSPNonceKey)
})
}
func TestNonceTemplate(t *testing.T) {
t.Run("constant_value", func(t *testing.T) {
assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
})
}
func TestEnhanceCSPPolicy(t *testing.T) {
t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, CloudflareInsightsDomain)
})
t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
enhanced := enhanceCSPPolicy(policy)
// Should not duplicate
count := strings.Count(enhanced, NonceTemplate)
assert.Equal(t, 1, count)
})
t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
enhanced := enhanceCSPPolicy(policy)
count := strings.Count(enhanced, CloudflareInsightsDomain)
assert.Equal(t, 1, count)
})
t.Run("handles_policy_without_script_src", func(t *testing.T) {
policy := "default-src 'self'"
enhanced := enhanceCSPPolicy(policy)
assert.Contains(t, enhanced, "script-src")
assert.Contains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, CloudflareInsightsDomain)
})
t.Run("preserves_existing_nonce", func(t *testing.T) {
policy := "script-src 'self' 'nonce-existing'"
enhanced := enhanceCSPPolicy(policy)
// Should not add placeholder if nonce already exists
assert.NotContains(t, enhanced, NonceTemplate)
assert.Contains(t, enhanced, "'nonce-existing'")
})
}
func TestAddToDirective(t *testing.T) {
t.Run("adds_to_existing_directive", func(t *testing.T) {
policy := "script-src 'self'; style-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src 'self' https://example.com")
})
t.Run("creates_directive_if_not_exists", func(t *testing.T) {
policy := "default-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src")
assert.Contains(t, result, "https://example.com")
})
t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
policy := "default-src 'self'; script-src 'self'"
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "https://example.com")
})
t.Run("handles_empty_policy", func(t *testing.T) {
policy := ""
result := addToDirective(policy, "script-src", "https://example.com")
assert.Contains(t, result, "script-src")
assert.Contains(t, result, "https://example.com")
})
}
// Benchmark tests
func BenchmarkGenerateNonce(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateNonce()
}
}
func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
cfg := config.CSPConfig{
Enabled: true,
Policy: "script-src 'self' __CSP_NONCE__",
}
middleware := SecurityHeaders(cfg)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
middleware(c)
}
}

View File

@@ -1,6 +1,8 @@
package server
import (
"log"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
@@ -9,6 +11,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SetupRouter 配置路由器中间件和路由
@@ -20,20 +23,31 @@ func SetupRouter(
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config,
redisClient *redis.Client,
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
r.Use(middleware2.CORS(cfg.CORS))
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
// Serve embedded frontend if available
// Serve embedded frontend with settings injection if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
frontendServer, err := web.NewFrontendServer(settingService)
if err != nil {
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
r.Use(web.ServeEmbeddedFrontend())
} else {
// Register cache invalidation callback
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
r.Use(frontendServer.Middleware())
}
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
return r
}
@@ -47,7 +61,9 @@ func registerRoutes(
apiKeyAuth middleware2.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
cfg *config.Config,
redisClient *redis.Client,
) {
// 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r)
@@ -56,8 +72,8 @@ func registerRoutes(
v1 := r.Group("/api/v1")
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth)
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
}

View File

@@ -44,9 +44,15 @@ func RegisterAdminRoutes(
// 卡密管理
registerRedeemCodeRoutes(admin, h)
// 优惠码管理
registerPromoCodeRoutes(admin, h)
// 系统设置
registerSettingsRoutes(admin, h)
// 运维监控Ops
registerOpsRoutes(admin, h)
// 系统管理
registerSystemRoutes(admin, h)
@@ -61,6 +67,85 @@ func RegisterAdminRoutes(
}
}
func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops := admin.Group("/ops")
{
// Realtime ops signals
ops.GET("/concurrency", h.Admin.Ops.GetConcurrencyStats)
ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability)
ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary)
// Alerts (rules + events)
ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules)
ops.POST("/alert-rules", h.Admin.Ops.CreateAlertRule)
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
ops.GET("/alert-events/:id", h.Admin.Ops.GetAlertEvent)
ops.PUT("/alert-events/:id/status", h.Admin.Ops.UpdateAlertEventStatus)
ops.POST("/alert-silences", h.Admin.Ops.CreateAlertSilence)
// Email notification config (DB-backed)
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
ops.PUT("/email-notification/config", h.Admin.Ops.UpdateEmailNotificationConfig)
// Runtime settings (DB-backed)
runtime := ops.Group("/runtime")
{
runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings)
runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings)
}
// Advanced settings (DB-backed)
ops.GET("/advanced-settings", h.Admin.Ops.GetAdvancedSettings)
ops.PUT("/advanced-settings", h.Admin.Ops.UpdateAdvancedSettings)
// Settings group (DB-backed)
settings := ops.Group("/settings")
{
settings.GET("/metric-thresholds", h.Admin.Ops.GetMetricThresholds)
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
}
// WebSocket realtime (QPS/TPS)
ws := ops.Group("/ws")
{
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
}
// Error logs (legacy)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
// Request errors (client-visible failures)
ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
// Upstream errors (independent upstream failures)
ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
// Request drilldown (success + error)
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
// Dashboard (vNext - raw path for MVP)
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend)
ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution)
}
}
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard := admin.Group("/dashboard")
{
@@ -72,6 +157,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
}
}
@@ -183,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
}
@@ -201,6 +288,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
promoCodes := admin.Group("/promo-codes")
{
promoCodes.GET("", h.Admin.Promo.List)
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
promoCodes.POST("", h.Admin.Promo.Create)
promoCodes.PUT("/:id", h.Admin.Promo.Update)
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings")
{
@@ -212,6 +311,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
}
}

View File

@@ -1,24 +1,36 @@
package routes
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/middleware"
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
jwtAuth servermiddleware.JWTAuthMiddleware,
redisClient *redis.Client,
) {
// 创建速率限制器
rateLimiter := middleware.NewRateLimiter(redisClient)
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}

View File

@@ -16,13 +16,18 @@ func RegisterGatewayRoutes(
apiKeyAuth middleware.APIKeyAuthMiddleware,
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
{
gateway.POST("/messages", h.Gateway.Messages)
@@ -36,6 +41,8 @@ func RegisterGatewayRoutes(
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
@@ -45,7 +52,7 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
@@ -53,6 +60,8 @@ func RegisterGatewayRoutes(
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
antigravityV1.Use(clientRequestID)
antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{
@@ -64,6 +73,8 @@ func RegisterGatewayRoutes(
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(clientRequestID)
antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{