Merge upstream/main: v0.1.88-v0.1.101 updates
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -86,6 +86,15 @@ func TestAPIContracts(t *testing.T) {
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"rate_limit_5h": 0,
|
||||
"rate_limit_1d": 0,
|
||||
"rate_limit_7d": 0,
|
||||
"usage_5h": 0,
|
||||
"usage_1d": 0,
|
||||
"usage_7d": 0,
|
||||
"window_5h_start": null,
|
||||
"window_1d_start": null,
|
||||
"window_7d_start": null,
|
||||
"expires_at": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
@@ -126,6 +135,15 @@ func TestAPIContracts(t *testing.T) {
|
||||
"last_used_at": null,
|
||||
"quota": 0,
|
||||
"quota_used": 0,
|
||||
"rate_limit_5h": 0,
|
||||
"rate_limit_1d": 0,
|
||||
"rate_limit_7d": 0,
|
||||
"usage_5h": 0,
|
||||
"usage_1d": 0,
|
||||
"usage_7d": 0,
|
||||
"window_5h_start": null,
|
||||
"window_1d_start": null,
|
||||
"window_7d_start": null,
|
||||
"expires_at": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
@@ -192,8 +210,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"sora_video_price_per_request": null,
|
||||
"sora_video_price_per_request_hd": null,
|
||||
"claude_code_only": false,
|
||||
"allow_messages_dispatch": false,
|
||||
"fallback_group_id": null,
|
||||
"fallback_group_id_on_invalid_request": null,
|
||||
"allow_messages_dispatch": false,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
}
|
||||
@@ -428,9 +448,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
setup: func(t *testing.T, deps *contractDeps) {
|
||||
t.Helper()
|
||||
deps.settingRepo.SetAll(map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
service.SettingKeyEmailVerifyEnabled: "false",
|
||||
service.SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
service.SettingKeyPromoCodeEnabled: "true",
|
||||
|
||||
service.SettingKeySMTPHost: "smtp.example.com",
|
||||
service.SettingKeySMTPPort: "587",
|
||||
@@ -469,8 +490,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"data": {
|
||||
"registration_enabled": true,
|
||||
"email_verify_enabled": false,
|
||||
"registration_email_suffix_whitelist": [],
|
||||
"promo_code_enabled": true,
|
||||
"password_reset_enabled": false,
|
||||
"frontend_url": "",
|
||||
"totp_enabled": false,
|
||||
"totp_encryption_key_configured": false,
|
||||
"smtp_host": "smtp.example.com",
|
||||
@@ -513,7 +536,10 @@ func TestAPIContracts(t *testing.T) {
|
||||
"hide_ccs_import_button": false,
|
||||
"purchase_subscription_enabled": false,
|
||||
"purchase_subscription_url": "",
|
||||
"min_claude_code_version": ""
|
||||
"min_claude_code_version": "",
|
||||
"allow_ungrouped_key_scheduling": false,
|
||||
"backend_mode_enabled": false,
|
||||
"custom_menu_items": []
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -621,7 +647,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
@@ -1026,6 +1052,14 @@ func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Conte
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
@@ -1066,6 +1100,14 @@ func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
return int64(len(ids)), nil
|
||||
@@ -1383,7 +1425,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
ids := make([]int64, 0, len(r.byID))
|
||||
for id := range r.byID {
|
||||
if r.byID[id].UserID == userID {
|
||||
@@ -1497,6 +1539,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUsageLogRepo struct {
|
||||
userLogs map[int64][]service.UsageLog
|
||||
}
|
||||
@@ -1573,6 +1625,14 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.EndpointStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
@@ -1585,6 +1645,10 @@ func (r *stubUsageLogRepo) GetUserUsageTrend(ctx context.Context, startTime, end
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUsageLogRepo) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
|
||||
logs := r.userLogs[userID]
|
||||
if len(logs) == 0 {
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
|
||||
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authService := service.NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
admin := &service.User{
|
||||
ID: 1,
|
||||
|
||||
@@ -19,8 +19,16 @@ func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionS
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
//
|
||||
// 中间件职责分为两层:
|
||||
// - 鉴权(Authentication):验证 Key 有效性、用户状态、IP 限制 —— 始终执行
|
||||
// - 计费执行(Billing Enforcement):过期/配额/订阅/余额检查 —— skipBilling 时整块跳过
|
||||
//
|
||||
// /v1/usage 端点只需鉴权,不需要计费执行(允许过期/配额耗尽的 Key 查询自身用量)。
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// ── 1. 提取 API Key ──────────────────────────────────────────
|
||||
|
||||
queryKey := strings.TrimSpace(c.Query("key"))
|
||||
queryApiKey := strings.TrimSpace(c.Query("api_key"))
|
||||
if queryKey != "" || queryApiKey != "" {
|
||||
@@ -56,7 +64,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
// ── 2. 验证 Key 存在 ─────────────────────────────────────────
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrAPIKeyNotFound) {
|
||||
@@ -67,29 +76,13 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
// Provide more specific error message based on status
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
default:
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
}
|
||||
return
|
||||
}
|
||||
// ── 3. 基础鉴权(始终执行) ─────────────────────────────────
|
||||
|
||||
// 检查API Key是否过期(即使状态是active,也要检查时间)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API Key配额是否耗尽
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
// disabled / 未知状态 → 无条件拦截(expired 和 quota_exhausted 留给计费阶段)
|
||||
if !apiKey.IsActive() &&
|
||||
apiKey.Status != service.StatusAPIKeyExpired &&
|
||||
apiKey.Status != service.StatusAPIKeyQuotaExhausted {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -116,8 +109,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// ── 4. SimpleMode → early return ─────────────────────────────
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
@@ -130,54 +124,89 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
// ── 5. 加载订阅(订阅模式时始终加载) ───────────────────────
|
||||
|
||||
// skipBilling: /v1/usage 只需鉴权,跳过所有计费执行
|
||||
skipBilling := c.Request.URL.Path == "/v1/usage"
|
||||
|
||||
var subscription *service.UserSubscription
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:获取订阅(L1 缓存 + singleflight)
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
sub, subErr := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 合并验证 + 限额检查(纯内存操作)
|
||||
needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if err != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(err, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(err, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
if subErr != nil {
|
||||
if !skipBilling {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, status, code, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
// 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
// skipBilling: 订阅不存在也放行,handler 会返回可用的数据
|
||||
} else {
|
||||
subscription = sub
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
// ── 6. 计费执行(skipBilling 时整块跳过) ────────────────────
|
||||
|
||||
if !skipBilling {
|
||||
// Key 状态检查
|
||||
switch apiKey.Status {
|
||||
case service.StatusAPIKeyQuotaExhausted:
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
case service.StatusAPIKeyExpired:
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
|
||||
// 运行时过期/配额检查(即使状态是 active,也要检查时间和用量)
|
||||
if apiKey.IsExpired() {
|
||||
AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期")
|
||||
return
|
||||
}
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完")
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅模式:验证订阅限额
|
||||
if subscription != nil {
|
||||
needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group)
|
||||
if validateErr != nil {
|
||||
code := "SUBSCRIPTION_INVALID"
|
||||
status := 403
|
||||
if errors.Is(validateErr, service.ErrDailyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrWeeklyLimitExceeded) ||
|
||||
errors.Is(validateErr, service.ErrMonthlyLimitExceeded) {
|
||||
code = "USAGE_LIMIT_EXCEEDED"
|
||||
status = 429
|
||||
}
|
||||
AbortWithError(c, status, code, validateErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 窗口维护异步化(不阻塞请求)
|
||||
if needsMaintenance {
|
||||
maintenanceCopy := *subscription
|
||||
subscriptionService.DoWindowMaintenance(&maintenanceCopy)
|
||||
}
|
||||
} else {
|
||||
// 非订阅模式 或 订阅模式但 subscriptionService 未注入:回退到余额检查
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 7. 设置上下文 → Next ─────────────────────────────────────
|
||||
|
||||
if subscription != nil {
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
}
|
||||
c.Set(string(ContextKeyAPIKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
|
||||
@@ -56,7 +56,7 @@ func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error {
|
||||
func (f fakeAPIKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (f fakeAPIKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeAPIKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
@@ -95,6 +95,15 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (f fakeAPIKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return &service.APIKeyRateLimitData{}, nil
|
||||
}
|
||||
|
||||
func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
|
||||
@@ -537,7 +537,7 @@ func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, _ service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -588,6 +588,16 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
|
||||
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
51
backend/internal/server/middleware/backend_mode_guard.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BackendModeUserGuard blocks non-admin users from accessing user routes when backend mode is enabled.
|
||||
// Must be placed AFTER JWT auth middleware so that the user role is available in context.
|
||||
func BackendModeUserGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
role, _ := GetUserRoleFromContext(c)
|
||||
if role == "admin" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. User self-service is disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
// BackendModeAuthGuard selectively blocks auth endpoints when backend mode is enabled.
|
||||
// Allows: login, login/2fa, logout, refresh (admin needs these).
|
||||
// Blocks: register, forgot-password, reset-password, OAuth, etc.
|
||||
func BackendModeAuthGuard(settingService *service.SettingService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if settingService == nil || !settingService.IsBackendModeEnabled(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
path := c.Request.URL.Path
|
||||
// Allow login, 2FA, logout, refresh, public settings
|
||||
allowedSuffixes := []string{"/auth/login", "/auth/login/2fa", "/auth/logout", "/auth/refresh"}
|
||||
for _, suffix := range allowedSuffixes {
|
||||
if strings.HasSuffix(path, suffix) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
response.Forbidden(c, "Backend mode is active. Registration and self-service auth flows are disabled.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
239
backend/internal/server/middleware/backend_mode_guard_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type bmSettingRepo struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Get(_ context.Context, _ string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
v, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Set(_ context.Context, _, _ string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetMultiple(_ context.Context, _ []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
if r.values == nil {
|
||||
r.values = make(map[string]string, len(settings))
|
||||
}
|
||||
for key, value := range settings {
|
||||
r.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (r *bmSettingRepo) Delete(_ context.Context, _ string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newBackendModeSettingService(t *testing.T, enabled string) *service.SettingService {
|
||||
t.Helper()
|
||||
|
||||
repo := &bmSettingRepo{
|
||||
values: map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: enabled,
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{})
|
||||
require.NoError(t, svc.UpdateSettings(context.Background(), &service.SystemSettings{
|
||||
BackendModeEnabled: enabled == "true",
|
||||
}))
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func stringPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestBackendModeUserGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
role *string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_admin_allowed",
|
||||
enabled: "true",
|
||||
role: stringPtr("admin"),
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_user_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr("user"),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_no_role_blocked",
|
||||
enabled: "true",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_empty_role_blocked",
|
||||
enabled: "true",
|
||||
role: stringPtr(""),
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
if tc.role != nil {
|
||||
role := *tc.role
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(string(ContextKeyUserRole), role)
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeUserGuard(svc))
|
||||
r.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendModeAuthGuard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nilService bool
|
||||
enabled string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "disabled_allows_all",
|
||||
enabled: "false",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nil_service_allows_all",
|
||||
nilService: true,
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_login_2fa",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/login/2fa",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_logout",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/logout",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_allows_refresh",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/refresh",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_register",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/register",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "enabled_blocks_forgot_password",
|
||||
enabled: "true",
|
||||
path: "/api/v1/auth/forgot-password",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
|
||||
var svc *service.SettingService
|
||||
if !tc.nilService {
|
||||
svc = newBackendModeSettingService(t, tc.enabled)
|
||||
}
|
||||
|
||||
r.Use(BackendModeAuthGuard(svc))
|
||||
r.Any("/*path", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tc.wantStatus, w.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
||||
cfg.JWT.AccessTokenExpireMinutes = 60
|
||||
|
||||
userRepo := &stubJWTUserRepo{users: users}
|
||||
authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||
|
||||
|
||||
@@ -2,8 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────
|
||||
// RequireGroupAssignment — 未分组 Key 拦截中间件
|
||||
// ──────────────────────────────────────────────────────────
|
||||
|
||||
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
|
||||
type GatewayErrorWriter func(c *gin.Context, status int, message string)
|
||||
|
||||
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
|
||||
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{"type": "permission_error", "message": message},
|
||||
})
|
||||
}
|
||||
|
||||
// GoogleErrorWriter 按 Google API 规范输出错误
|
||||
func GoogleErrorWriter(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
|
||||
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
|
||||
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKey, ok := GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.GroupID != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// 未分组 Key — 检查系统设置
|
||||
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,9 @@ func GetNonceFromContext(c *gin.Context) string {
|
||||
}
|
||||
|
||||
// SecurityHeaders sets baseline security headers for all responses.
|
||||
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
// getFrameSrcOrigins is an optional function that returns extra origins to inject into frame-src;
|
||||
// pass nil to disable dynamic frame-src injection.
|
||||
func SecurityHeaders(cfg config.CSPConfig, getFrameSrcOrigins func() []string) gin.HandlerFunc {
|
||||
policy := strings.TrimSpace(cfg.Policy)
|
||||
if policy == "" {
|
||||
policy = config.DefaultCSPPolicy
|
||||
@@ -51,6 +53,15 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
policy = enhanceCSPPolicy(policy)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
finalPolicy := policy
|
||||
if getFrameSrcOrigins != nil {
|
||||
for _, origin := range getFrameSrcOrigins() {
|
||||
if origin != "" {
|
||||
finalPolicy = addToDirective(finalPolicy, "frame-src", origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
@@ -65,12 +76,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
|
||||
if err != nil {
|
||||
// crypto/rand 失败时降级为无 nonce 的 CSP 策略
|
||||
log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err)
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'unsafe-inline'"))
|
||||
} else {
|
||||
c.Set(CSPNonceKey, nonce)
|
||||
finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
|
||||
c.Header("Content-Security-Policy", finalPolicy)
|
||||
c.Header("Content-Security-Policy", strings.ReplaceAll(finalPolicy, NonceTemplate, "'nonce-"+nonce+"'"))
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
|
||||
@@ -84,7 +84,7 @@ func TestGetNonceFromContext(t *testing.T) {
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
t.Run("sets_basic_security_headers", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: false}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -99,7 +99,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
|
||||
t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: false}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -115,7 +115,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -136,7 +136,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -156,7 +156,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: "script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -180,7 +180,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: "",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -199,7 +199,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: " \t\n ",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -217,7 +217,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
@@ -235,7 +235,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
|
||||
t.Run("calls_next_handler", func(t *testing.T) {
|
||||
cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
nextCalled := false
|
||||
router := gin.New()
|
||||
@@ -258,7 +258,7 @@ func TestSecurityHeaders(t *testing.T) {
|
||||
Enabled: true,
|
||||
Policy: "script-src __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
nonces := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
@@ -376,7 +376,7 @@ func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
|
||||
Enabled: true,
|
||||
Policy: "script-src 'self' __CSP_NONCE__",
|
||||
}
|
||||
middleware := SecurityHeaders(cfg)
|
||||
middleware := SecurityHeaders(cfg, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
@@ -14,6 +17,8 @@ import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const frameSrcRefreshTimeout = 5 * time.Second
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
func SetupRouter(
|
||||
r *gin.Engine,
|
||||
@@ -28,11 +33,33 @@ func SetupRouter(
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src
|
||||
var cachedFrameOrigins atomic.Pointer[[]string]
|
||||
emptyOrigins := []string{}
|
||||
cachedFrameOrigins.Store(&emptyOrigins)
|
||||
|
||||
refreshFrameOrigins := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), frameSrcRefreshTimeout)
|
||||
defer cancel()
|
||||
origins, err := settingService.GetFrameSrcOrigins(ctx)
|
||||
if err != nil {
|
||||
// 获取失败时保留已有缓存,避免 frame-src 被意外清空
|
||||
return
|
||||
}
|
||||
cachedFrameOrigins.Store(&origins)
|
||||
}
|
||||
refreshFrameOrigins() // 启动时初始化
|
||||
|
||||
// 应用中间件
|
||||
r.Use(middleware2.RequestLogger())
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS(cfg.CORS))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP, func() []string {
|
||||
if p := cachedFrameOrigins.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
|
||||
// Serve embedded frontend with settings injection if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
@@ -40,15 +67,21 @@ func SetupRouter(
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
settingService.SetOnUpdateCallback(refreshFrameOrigins)
|
||||
} else {
|
||||
// Register cache invalidation callback
|
||||
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
|
||||
// Register combined callback: invalidate HTML cache + refresh frame origins
|
||||
settingService.SetOnUpdateCallback(func() {
|
||||
frontendServer.InvalidateCache()
|
||||
refreshFrameOrigins()
|
||||
})
|
||||
r.Use(frontendServer.Middleware())
|
||||
}
|
||||
} else {
|
||||
settingService.SetOnUpdateCallback(refreshFrameOrigins)
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -63,6 +96,7 @@ func registerRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
@@ -73,9 +107,9 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterSoraClientRoutes(v1, h, jwtAuth, settingService)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ func RegisterAdminRoutes(
|
||||
// 数据管理
|
||||
registerDataManagementRoutes(admin, h)
|
||||
|
||||
// 数据库备份恢复
|
||||
registerBackupRoutes(admin, h)
|
||||
|
||||
// 运维监控(Ops)
|
||||
registerOpsRoutes(admin, h)
|
||||
|
||||
@@ -78,6 +81,9 @@ func RegisterAdminRoutes(
|
||||
|
||||
// API Key 管理
|
||||
registerAdminAPIKeyRoutes(admin, h)
|
||||
|
||||
// 定时测试计划
|
||||
registerScheduledTestRoutes(admin, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,6 +174,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth)
|
||||
|
||||
// Dashboard (vNext - raw path for MVP)
|
||||
ops.GET("/dashboard/snapshot-v2", h.Admin.Ops.GetDashboardSnapshotV2)
|
||||
ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview)
|
||||
ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend)
|
||||
ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram)
|
||||
@@ -180,6 +187,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard := admin.Group("/dashboard")
|
||||
{
|
||||
dashboard.GET("/snapshot-v2", h.Admin.Dashboard.GetSnapshotV2)
|
||||
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
|
||||
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
|
||||
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
|
||||
@@ -187,6 +195,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
|
||||
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
|
||||
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
|
||||
dashboard.GET("/users-ranking", h.Admin.Dashboard.GetUserSpendingRanking)
|
||||
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
|
||||
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage)
|
||||
dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation)
|
||||
@@ -223,6 +232,9 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
groups.PUT("/:id", h.Admin.Group.Update)
|
||||
groups.DELETE("/:id", h.Admin.Group.Delete)
|
||||
groups.GET("/:id/stats", h.Admin.Group.GetStats)
|
||||
groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers)
|
||||
groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers)
|
||||
groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers)
|
||||
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
|
||||
}
|
||||
}
|
||||
@@ -239,6 +251,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.PUT("/:id", h.Admin.Account.Update)
|
||||
accounts.DELETE("/:id", h.Admin.Account.Delete)
|
||||
accounts.POST("/:id/test", h.Admin.Account.Test)
|
||||
accounts.POST("/:id/recover-state", h.Admin.Account.RecoverState)
|
||||
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
|
||||
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
|
||||
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
|
||||
@@ -247,6 +260,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
|
||||
accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats)
|
||||
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
|
||||
accounts.POST("/:id/reset-quota", h.Admin.Account.ResetQuota)
|
||||
accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable)
|
||||
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
|
||||
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
|
||||
@@ -257,6 +271,8 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
|
||||
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
|
||||
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
|
||||
accounts.POST("/batch-clear-error", h.Admin.Account.BatchClearError)
|
||||
accounts.POST("/batch-refresh", h.Admin.Account.BatchRefresh)
|
||||
|
||||
// Antigravity 默认模型映射
|
||||
accounts.GET("/antigravity/default-model-mapping", h.Admin.Account.GetAntigravityDefaultModelMapping)
|
||||
@@ -386,6 +402,12 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
// 流超时处理配置
|
||||
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
|
||||
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
|
||||
// 请求整流器配置
|
||||
adminSettings.GET("/rectifier", h.Admin.Setting.GetRectifierSettings)
|
||||
adminSettings.PUT("/rectifier", h.Admin.Setting.UpdateRectifierSettings)
|
||||
// Beta 策略配置
|
||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||
// Sora S3 存储配置
|
||||
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
|
||||
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
|
||||
@@ -421,6 +443,30 @@ func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerBackupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
backup := admin.Group("/backups")
|
||||
{
|
||||
// S3 存储配置
|
||||
backup.GET("/s3-config", h.Admin.Backup.GetS3Config)
|
||||
backup.PUT("/s3-config", h.Admin.Backup.UpdateS3Config)
|
||||
backup.POST("/s3-config/test", h.Admin.Backup.TestS3Connection)
|
||||
|
||||
// 定时备份配置
|
||||
backup.GET("/schedule", h.Admin.Backup.GetSchedule)
|
||||
backup.PUT("/schedule", h.Admin.Backup.UpdateSchedule)
|
||||
|
||||
// 备份操作
|
||||
backup.POST("", h.Admin.Backup.CreateBackup)
|
||||
backup.GET("", h.Admin.Backup.ListBackups)
|
||||
backup.GET("/:id", h.Admin.Backup.GetBackup)
|
||||
backup.DELETE("/:id", h.Admin.Backup.DeleteBackup)
|
||||
backup.GET("/:id/download-url", h.Admin.Backup.GetDownloadURL)
|
||||
|
||||
// 恢复操作
|
||||
backup.POST("/:id/restore", h.Admin.Backup.RestoreBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
system := admin.Group("/system")
|
||||
{
|
||||
@@ -441,6 +487,7 @@ func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
|
||||
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
|
||||
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
|
||||
subscriptions.POST("/:id/reset-quota", h.Admin.Subscription.ResetQuota)
|
||||
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
|
||||
}
|
||||
|
||||
@@ -476,6 +523,18 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerScheduledTestRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
plans := admin.Group("/scheduled-test-plans")
|
||||
{
|
||||
plans.POST("", h.Admin.ScheduledTest.Create)
|
||||
plans.PUT("/:id", h.Admin.ScheduledTest.Update)
|
||||
plans.DELETE("/:id", h.Admin.ScheduledTest.Delete)
|
||||
plans.GET("/:id/results", h.Admin.ScheduledTest.ListResults)
|
||||
}
|
||||
// Nested under accounts
|
||||
admin.GET("/accounts/:id/scheduled-test-plans", h.Admin.ScheduledTest.ListByAccount)
|
||||
}
|
||||
|
||||
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
rules := admin.Group("/error-passthrough-rules")
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,14 @@ func RegisterAuthRoutes(
|
||||
h *handler.Handlers,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
auth.Use(servermiddleware.BackendModeAuthGuard(settingService))
|
||||
{
|
||||
// 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close)
|
||||
auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{
|
||||
@@ -61,6 +64,12 @@ func RegisterAuthRoutes(
|
||||
}), h.Auth.ResetPassword)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
auth.POST("/oauth/linuxdo/complete-registration",
|
||||
rateLimiter.LimitWithOptions("oauth-linuxdo-complete", 10, time.Minute, middleware.RateLimitOptions{
|
||||
FailureMode: middleware.RateLimitFailClose,
|
||||
}),
|
||||
h.Auth.CompleteLinuxDoOAuthRegistration,
|
||||
)
|
||||
}
|
||||
|
||||
// 公开设置(无需认证)
|
||||
@@ -72,6 +81,7 @@ func RegisterAuthRoutes(
|
||||
// 需要认证的当前用户信息
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(servermiddleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
|
||||
// 撤销所有会话(需要认证)
|
||||
|
||||
@@ -29,6 +29,7 @@ func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine {
|
||||
c.Next()
|
||||
}),
|
||||
redisClient,
|
||||
nil,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
@@ -19,6 +19,7 @@ func RegisterGatewayRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
|
||||
@@ -29,30 +30,51 @@ func RegisterGatewayRoutes(
|
||||
soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize)
|
||||
clientRequestID := middleware.ClientRequestID()
|
||||
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
|
||||
endpointNorm := handler.InboundEndpointMiddleware()
|
||||
|
||||
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
|
||||
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
|
||||
requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter)
|
||||
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
gateway.Use(bodyLimit)
|
||||
gateway.Use(clientRequestID)
|
||||
gateway.Use(opsErrorLogger)
|
||||
gateway.Use(endpointNorm)
|
||||
gateway.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
gateway.Use(requireGroupAnthropic)
|
||||
{
|
||||
gateway.POST("/messages", h.Gateway.Messages)
|
||||
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
// /v1/messages: auto-route based on group platform
|
||||
gateway.POST("/messages", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
h.OpenAIGateway.Messages(c)
|
||||
return
|
||||
}
|
||||
h.Gateway.Messages(c)
|
||||
})
|
||||
// /v1/messages/count_tokens: OpenAI groups get 404
|
||||
gateway.POST("/messages/count_tokens", func(c *gin.Context) {
|
||||
if getGroupPlatform(c) == service.PlatformOpenAI {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "not_found_error",
|
||||
"message": "Token counting is not supported for this platform",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
h.Gateway.CountTokens(c)
|
||||
})
|
||||
gateway.GET("/models", h.Gateway.Models)
|
||||
gateway.GET("/usage", h.Gateway.Usage)
|
||||
// OpenAI Responses API
|
||||
gateway.POST("/responses", h.OpenAIGateway.Responses)
|
||||
gateway.POST("/responses/*subpath", h.OpenAIGateway.Responses)
|
||||
gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket)
|
||||
// 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。
|
||||
gateway.POST("/chat/completions", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "invalid_request_error",
|
||||
"message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.",
|
||||
},
|
||||
})
|
||||
})
|
||||
// OpenAI Chat Completions API
|
||||
gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions)
|
||||
}
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
@@ -60,7 +82,9 @@ func RegisterGatewayRoutes(
|
||||
gemini.Use(bodyLimit)
|
||||
gemini.Use(clientRequestID)
|
||||
gemini.Use(opsErrorLogger)
|
||||
gemini.Use(endpointNorm)
|
||||
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
gemini.Use(requireGroupGoogle)
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -69,19 +93,24 @@ func RegisterGatewayRoutes(
|
||||
}
|
||||
|
||||
// OpenAI Responses API(不带v1前缀的别名)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
|
||||
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.POST("/responses/*subpath", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
|
||||
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
|
||||
// OpenAI Chat Completions API(不带v1前缀的别名)
|
||||
r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, endpointNorm, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ChatCompletions)
|
||||
|
||||
// Antigravity 模型列表
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
|
||||
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
|
||||
|
||||
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
|
||||
antigravityV1 := r.Group("/antigravity/v1")
|
||||
antigravityV1.Use(bodyLimit)
|
||||
antigravityV1.Use(clientRequestID)
|
||||
antigravityV1.Use(opsErrorLogger)
|
||||
antigravityV1.Use(endpointNorm)
|
||||
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
antigravityV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
antigravityV1.POST("/messages", h.Gateway.Messages)
|
||||
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
|
||||
@@ -93,8 +122,10 @@ func RegisterGatewayRoutes(
|
||||
antigravityV1Beta.Use(bodyLimit)
|
||||
antigravityV1Beta.Use(clientRequestID)
|
||||
antigravityV1Beta.Use(opsErrorLogger)
|
||||
antigravityV1Beta.Use(endpointNorm)
|
||||
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
|
||||
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
antigravityV1Beta.Use(requireGroupGoogle)
|
||||
{
|
||||
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
@@ -106,8 +137,10 @@ func RegisterGatewayRoutes(
|
||||
soraV1.Use(soraBodyLimit)
|
||||
soraV1.Use(clientRequestID)
|
||||
soraV1.Use(opsErrorLogger)
|
||||
soraV1.Use(endpointNorm)
|
||||
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
|
||||
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
|
||||
soraV1.Use(requireGroupAnthropic)
|
||||
{
|
||||
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
|
||||
soraV1.GET("/models", h.Gateway.Models)
|
||||
@@ -122,3 +155,12 @@ func RegisterGatewayRoutes(
|
||||
// Sora 媒体代理(签名 URL,无需 API Key)
|
||||
r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned)
|
||||
}
|
||||
|
||||
// getGroupPlatform extracts the group platform from the API Key stored in context.
|
||||
func getGroupPlatform(c *gin.Context) string {
|
||||
apiKey, ok := middleware.GetAPIKeyFromContext(c)
|
||||
if !ok || apiKey.Group == nil {
|
||||
return ""
|
||||
}
|
||||
return apiKey.Group.Platform
|
||||
}
|
||||
|
||||
51
backend/internal/server/routes/gateway_test.go
Normal file
51
backend/internal/server/routes/gateway_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGatewayRoutesTestRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
RegisterGatewayRoutes(
|
||||
router,
|
||||
&handler.Handlers{
|
||||
Gateway: &handler.GatewayHandler{},
|
||||
OpenAIGateway: &handler.OpenAIGatewayHandler{},
|
||||
SoraGateway: &handler.SoraGatewayHandler{},
|
||||
},
|
||||
servermiddleware.APIKeyAuthMiddleware(func(c *gin.Context) {
|
||||
c.Next()
|
||||
}),
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{},
|
||||
)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func TestGatewayRoutesOpenAIResponsesCompactPathIsRegistered(t *testing.T) {
|
||||
router := newGatewayRoutesTestRouter()
|
||||
|
||||
for _, path := range []string{"/v1/responses/compact", "/responses/compact"} {
|
||||
req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{"model":"gpt-5"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
require.NotEqual(t, http.StatusNotFound, w.Code, "path=%s should hit OpenAI responses handler", path)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,6 +13,7 @@ func RegisterSoraClientRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
if h.SoraClient == nil {
|
||||
return
|
||||
@@ -19,6 +21,7 @@ func RegisterSoraClientRoutes(
|
||||
|
||||
authenticated := v1.Group("/sora")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
authenticated.POST("/generate", h.SoraClient.Generate)
|
||||
authenticated.GET("/generations", h.SoraClient.ListGenerations)
|
||||
|
||||
@@ -3,6 +3,7 @@ package routes
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -12,9 +13,11 @@ func RegisterUserRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
settingService *service.SettingService,
|
||||
) {
|
||||
authenticated := v1.Group("")
|
||||
authenticated.Use(gin.HandlerFunc(jwtAuth))
|
||||
authenticated.Use(middleware.BackendModeUserGuard(settingService))
|
||||
{
|
||||
// 用户接口
|
||||
user := authenticated.Group("/user")
|
||||
|
||||
Reference in New Issue
Block a user