From ecfad788d938832c3eacc38f5e9e6e1444c2eb50 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Mon, 29 Dec 2025 03:17:25 +0800 Subject: [PATCH] =?UTF-8?q?feat(=E5=85=A8=E6=A0=88):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E7=AE=80=E6=98=93=E6=A8=A1=E5=BC=8F=E6=A0=B8=E5=BF=83=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **功能概述**: 实现简易模式(Simple Mode),为个人用户和小团队提供简化的使用体验,隐藏复杂的分组、订阅、配额等概念。 **后端改动**: 1. 配置系统 - 新增 run_mode 配置项(standard/simple) - 支持环境变量 RUN_MODE - 默认值为 standard 2. 数据库初始化 - 自动创建3个默认分组:anthropic-default、openai-default、gemini-default - 默认分组配置:无并发限制、active状态、非独占 - 幂等性保证:重复启动不会重复创建 3. 账号管理 - 创建账号时自动绑定对应平台的默认分组 - 如果未指定分组,自动查找并绑定默认分组 **前端改动**: 1. 状态管理 - authStore 新增 isSimpleMode 计算属性 - 从后端API获取并同步运行模式 2. UI隐藏 - 侧边栏:隐藏分组管理、订阅管理、兑换码菜单 - 账号管理页面:隐藏分组列 - 创建/编辑账号对话框:隐藏分组选择器 3. 路由守卫 - 限制访问分组、订阅、兑换码相关页面 - 访问受限页面时自动重定向到仪表板 **配置示例**: ```yaml run_mode: simple run_mode: standard ``` **影响范围**: - 后端:配置、数据库迁移、账号服务 - 前端:认证状态、路由、UI组件 - 部署:配置文件示例 **兼容性**: - 简易模式和标准模式可无缝切换 - 不需要数据迁移 - 现有数据不受影响 --- README_CN.md | 10 + backend/cmd/server/main.go | 8 + backend/cmd/server/wire_gen.go | 6 +- backend/internal/config/config.go | 20 ++ backend/internal/config/config_test.go | 23 ++ backend/internal/handler/auth_handler.go | 17 +- backend/internal/repository/auto_migrate.go | 57 ++++ backend/internal/server/api_contract_test.go | 6 +- backend/internal/server/http.go | 2 +- .../server/middleware/api_key_auth.go | 19 +- .../server/middleware/api_key_auth_google.go | 19 +- .../server/middleware/api_key_auth_test.go | 286 ++++++++++++++++++ backend/internal/server/router.go | 7 +- backend/internal/server/routes/gateway.go | 4 +- backend/internal/service/admin_service.go | 22 +- .../internal/service/billing_cache_service.go | 10 +- backend/internal/service/gateway_service.go | 11 +- .../service/openai_gateway_service.go | 12 +- deploy/.env.example | 4 + deploy/config.example.yaml | 8 + deploy/docker-compose.yml | 1 + frontend/src/api/auth.ts | 7 +- .../components/account/CreateAccountModal.vue | 11 +- .../components/account/EditAccountModal.vue | 11 +- frontend/src/components/layout/AppSidebar.vue | 4 +- frontend/src/router/index.ts | 17 ++ frontend/src/stores/auth.ts | 19 +- frontend/src/types/index.ts | 4 + frontend/src/views/admin/AccountsView.vue | 41 ++- 29 files changed, 615 insertions(+), 51 deletions(-) create mode 100644 backend/internal/config/config_test.go create mode 100644 backend/internal/server/middleware/api_key_auth_test.go diff --git a/README_CN.md b/README_CN.md index a93fb9d8..db7de488 100644 --- a/README_CN.md +++ b/README_CN.md @@ -283,6 +283,16 @@ npm run dev --- +## 简易模式 + +简易模式适合个人开发者或内部团队快速使用,不依赖完整 SaaS 功能。 + +- 启用方式:设置环境变量 `RUN_MODE=simple` +- 功能差异:隐藏 SaaS 相关功能,跳过计费流程 +- 安全注意事项:生产环境需同时设置 `SIMPLE_MODE_CONFIRM=true` 才允许启动 + +--- + ## 项目结构 ``` diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index a81a572e..6b87eb73 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -107,6 +107,14 @@ func runSetupServer() { } func runMainServer() { + cfg, err := config.Load() + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + if cfg.RunMode == config.RunModeSimple { + log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED") + } + buildInfo := handler.BuildInfo{ Version: Version, BuildType: BuildType, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1ff07f1e..c5b31bd5 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(authService, userService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewApiKeyRepository(db) groupRepository := repository.NewGroupRepository(db) @@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(db) billingCache := repository.NewBillingCache(client) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository) + billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(client) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) @@ -128,7 +128,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) - apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService) + apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) httpServer := server.ProvideHTTPServer(configConfig, engine) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 485ed42d..5ce1222f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -7,6 +7,11 @@ import ( "github.com/spf13/viper" ) +const ( + RunModeStandard = "standard" + RunModeSimple = "simple" +) + type Config struct { Server ServerConfig `mapstructure:"server"` Database DatabaseConfig `mapstructure:"database"` @@ -17,6 +22,7 @@ type Config struct { Pricing PricingConfig `mapstructure:"pricing"` Gateway GatewayConfig `mapstructure:"gateway"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` } @@ -135,6 +141,16 @@ type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } +func NormalizeRunMode(value string) string { + normalized := strings.ToLower(strings.TrimSpace(value)) + switch normalized { + case RunModeStandard, RunModeSimple: + return normalized + default: + return RunModeStandard + } +} + func Load() (*Config, error) { viper.SetConfigName("config") viper.SetConfigType("yaml") @@ -161,6 +177,8 @@ func Load() (*Config, error) { return nil, fmt.Errorf("unmarshal config error: %w", err) } + cfg.RunMode = NormalizeRunMode(cfg.RunMode) + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } @@ -169,6 +187,8 @@ func Load() (*Config, error) { } func setDefaults() { + viper.SetDefault("run_mode", RunModeStandard) + // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go new file mode 100644 index 00000000..1f1becb8 --- /dev/null +++ b/backend/internal/config/config_test.go @@ -0,0 +1,23 @@ +package config + +import "testing" + +func TestNormalizeRunMode(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"simple", "simple"}, + {"SIMPLE", "simple"}, + {"standard", "standard"}, + {"invalid", "standard"}, + {"", "standard"}, + } + + for _, tt := range tests { + result := NormalizeRunMode(tt.input) + if result != tt.expected { + t.Errorf("NormalizeRunMode(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 799d63d8..8466f131 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,6 +1,7 @@ package handler import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -11,13 +12,15 @@ import ( // AuthHandler handles authentication-related requests type AuthHandler struct { + cfg *config.Config authService *service.AuthService userService *service.UserService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { return &AuthHandler{ + cfg: cfg, authService: authService, userService: userService, } @@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { return } - response.Success(c, dto.UserFromService(user)) + type UserResponse struct { + *dto.User + RunMode string `json:"run_mode"` + } + + runMode := config.RunModeStandard + if h.cfg != nil { + runMode = h.cfg.RunMode + } + + response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) } diff --git a/backend/internal/repository/auto_migrate.go b/backend/internal/repository/auto_migrate.go index 9127eeb9..f76e3719 100644 --- a/backend/internal/repository/auto_migrate.go +++ b/backend/internal/repository/auto_migrate.go @@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error { return err } + // 创建默认分组(简易模式支持) + if err := ensureDefaultGroups(db); err != nil { + return err + } + // 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败) return fixInvalidExpiresAt(db) } @@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error { } return nil } + +// ensureDefaultGroups 确保默认分组存在(简易模式支持) +// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制 +func ensureDefaultGroups(db *gorm.DB) error { + defaultGroups := []struct { + name string + platform string + description string + }{ + { + name: "anthropic-default", + platform: "anthropic", + description: "Default group for Anthropic accounts (Simple Mode)", + }, + { + name: "openai-default", + platform: "openai", + description: "Default group for OpenAI accounts (Simple Mode)", + }, + { + name: "gemini-default", + platform: "gemini", + description: "Default group for Gemini accounts (Simple Mode)", + }, + } + + for _, dg := range defaultGroups { + var count int64 + if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&count).Error; err != nil { + return err + } + + if count == 0 { + group := &groupModel{ + Name: dg.name, + Description: dg.description, + Platform: dg.platform, + RateMultiplier: 1.0, + IsExclusive: false, + Status: "active", + SubscriptionType: "standard", + } + if err := db.Create(group).Error; err != nil { + log.Printf("[AutoMigrate] Failed to create default group %s: %v", dg.name, err) + return err + } + log.Printf("[AutoMigrate] Created default group: %s (platform: %s)", dg.name, dg.platform) + } + } + + return nil +} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 59bf7a44..ba5d173f 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) { "status": "active", "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", - "updated_at": "2025-01-02T03:04:05Z" + "updated_at": "2025-01-02T03:04:05Z", + "run_mode": "standard" } }`, }, @@ -371,6 +372,7 @@ func newContractDeps(t *testing.T) *contractDeps { Default: config.DefaultConfig{ ApiKeyPrefix: "sk-", }, + RunMode: config.RunModeStandard, } userService := service.NewUserService(userRepo) @@ -382,7 +384,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - authHandler := handler.NewAuthHandler(nil, userService) + authHandler := handler.NewAuthHandler(cfg, nil, userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 88833d63..b64220d9 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -36,7 +36,7 @@ func ProvideRouter( r := gin.New() r.Use(middleware2.Recovery()) - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) } // ProvideHTTPServer 提供 HTTP 服务器 diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index c4620d91..75e508dd 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -5,18 +5,19 @@ import ( "log" "strings" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) // NewApiKeyAuthMiddleware 创建 API Key 认证中间件 -func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware { - return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService)) +func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware { + return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)) } // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) -func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc { +func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { // 尝试从Authorization header中提取API key (Bearer scheme) authHeader := c.GetHeader("Authorization") @@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti return } + if cfg.RunMode == config.RunModeSimple { + // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 + c.Set(string(ContextKeyApiKey), apiKey) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: apiKey.User.ID, + Concurrency: apiKey.User.Concurrency, + }) + c.Set(string(ContextKeyUserRole), apiKey.User.Role) + c.Next() + return + } + // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 199aca82..d8f47bd2 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -4,6 +4,7 @@ import ( "errors" "strings" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/service" @@ -11,15 +12,15 @@ import ( ) // ApiKeyAuthGoogle is a Google-style error wrapper for API key auth. -func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService) gin.HandlerFunc { - return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil) +func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc { + return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg) } // ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors: // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. -func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc { +func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { apiKeyString := extractAPIKeyFromRequest(c) if apiKeyString == "" { @@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs return } + // 简易模式:跳过余额和订阅检查 + if cfg.RunMode == config.RunModeSimple { + c.Set(string(ContextKeyApiKey), apiKey) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: apiKey.User.ID, + Concurrency: apiKey.User.Concurrency, + }) + c.Set(string(ContextKeyUserRole), apiKey.User.Role) + c.Next() + return + } + isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { subscription, err := subscriptionService.GetActiveSubscription( diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go new file mode 100644 index 00000000..a9d22ede --- /dev/null +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -0,0 +1,286 @@ +//go:build unit + +package middleware + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestSimpleModeBypassesQuotaCheck(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := 1.0 + group := &service.Group{ + ID: 42, + Name: "sub", + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeSubscription, + DailyLimitUSD: &limit, + } + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.ApiKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { + if key != apiKey.Key { + return nil, service.ErrApiKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + + now := time.Now() + sub := &service.UserSubscription{ + ID: 55, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: now.Add(24 * time.Hour), + DailyWindowStart: &now, + DailyUsageUSD: 10, + } + subscriptionRepo := &stubUserSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if userID != sub.UserID || groupID != sub.GroupID { + return nil, service.ErrSubscriptionNotFound + } + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + } + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED") + }) +} + +func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { + router := gin.New() + router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + return router +} + +type stubApiKeyRepo struct { + getByKey func(ctx context.Context, key string) (*service.ApiKey, error) +} + +func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { + if r.getByKey != nil { + return r.getByKey(ctx, key) + } + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { + return false, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubUserSubscriptionRepo struct { + getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) + updateStatus func(ctx context.Context, subscriptionID int64, status string) error + activateWindow func(ctx context.Context, id int64, start time.Time) error + resetDaily func(ctx context.Context, id int64, start time.Time) error + resetWeekly func(ctx context.Context, id int64, start time.Time) error + resetMonthly func(ctx context.Context, id int64, start time.Time) error +} + +func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if r.getActive != nil { + return r.getActive(ctx, userID, groupID) + } + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + if r.updateStatus != nil { + return r.updateStatus(ctx, subscriptionID, status) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + if r.activateWindow != nil { + return r.activateWindow(ctx, id, start) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + if r.resetDaily != nil { + return r.resetDaily(ctx, id, newWindowStart) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + if r.resetWeekly != nil { + return r.resetWeekly(ctx, id, newWindowStart) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + if r.resetMonthly != nil { + return r.resetMonthly(ctx, id, newWindowStart) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 5489468b..2371dafb 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,6 +1,7 @@ package server import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/routes" @@ -19,6 +20,7 @@ func SetupRouter( apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, + cfg *config.Config, ) *gin.Engine { // 应用中间件 r.Use(middleware2.Logger()) @@ -30,7 +32,7 @@ func SetupRouter( } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) return r } @@ -44,6 +46,7 @@ func registerRoutes( apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, + cfg *config.Config, ) { // 通用路由(健康检查、状态等) routes.RegisterCommonRoutes(r) @@ -55,5 +58,5 @@ func registerRoutes( routes.RegisterAuthRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterAdminRoutes(v1, h, adminAuth) - routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService) + routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg) } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index eab36ef8..27864ba0 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,7 @@ package routes import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -15,6 +16,7 @@ func RegisterGatewayRoutes( apiKeyAuth middleware.ApiKeyAuthMiddleware, apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, + cfg *config.Config, ) { // API网关(Claude API兼容) gateway := r.Group("/v1") @@ -30,7 +32,7 @@ func RegisterGatewayRoutes( // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") - gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService)) + gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index db207ce5..9ffd342d 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } + // 绑定分组 - if len(input.GroupIDs) > 0 { - if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil { + groupIDs := input.GroupIDs + // 如果没有指定分组,自动绑定对应平台的默认分组 + if len(groupIDs) == 0 { + defaultGroupName := input.Platform + "-default" + groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) + if err == nil { + for _, g := range groups { + if g.Name == defaultGroupName { + groupIDs = []int64{g.ID} + log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID) + break + } + } + } + } + + if len(groupIDs) > 0 { + if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { return nil, err } } + return account, nil } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 18f125ca..9493a11f 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -6,6 +6,7 @@ import ( "log" "time" + "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" ) @@ -32,14 +33,16 @@ type BillingCacheService struct { cache BillingCache userRepo UserRepository subRepo UserSubscriptionRepository + cfg *config.Config } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService { +func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { return &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, + cfg: cfg, } } @@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID // 余额模式:检查缓存余额 > 0 // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error { + // 简易模式:跳过所有计费检查 + if s.cfg.RunMode == config.RunModeSimple { + return nil + } + // 判断计费模式 isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index d4e1a07b..fdff5987 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -313,7 +313,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台) var accounts []Account var err error - if groupID != nil { + if s.cfg.RunMode == config.RunModeSimple { + // 简易模式:忽略 groupID,查询所有可用账号 + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic) + } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic) } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic) @@ -1065,6 +1068,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu log.Printf("Create usage log failed: %v", err) } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + // 根据计费类型执行扣费 if isSubscriptionBilling { // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 20bf57f2..79801b29 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log" "net/http" "regexp" "strconv" @@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 2. Get schedulable OpenAI accounts var accounts []Account var err error - if groupID != nil { + // 简易模式:忽略分组限制,查询所有可用账号 + if s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) @@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec _ = s.usageLogRepo.Create(ctx, usageLog) + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + // Deduct based on billing type if isSubscriptionBilling { if cost.TotalCost > 0 { diff --git a/deploy/.env.example b/deploy/.env.example index de7ea722..19fcc853 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -20,6 +20,10 @@ SERVER_PORT=8080 # Server mode: release or debug SERVER_MODE=release +# 运行模式: standard (默认) 或 simple (内部自用) +# standard: 完整 SaaS 功能,包含计费/余额校验;simple: 隐藏 SaaS 功能并跳过计费/余额校验 +RUN_MODE=standard + # Timezone TZ=Asia/Shanghai diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index b6df4f65..fcaa7b7c 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -13,6 +13,14 @@ server: # Mode: "debug" for development, "release" for production mode: "release" +# ============================================================================= +# Run Mode Configuration +# ============================================================================= +# Run mode: "standard" (default) or "simple" (for internal use) +# - standard: Full SaaS features with billing/balance checks +# - simple: Hides SaaS features and skips billing/balance checks +run_mode: "standard" + # ============================================================================= # Database Configuration (PostgreSQL) # ============================================================================= diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 9e10ec54..0e3fb16e 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -36,6 +36,7 @@ services: - SERVER_HOST=0.0.0.0 - SERVER_PORT=8080 - SERVER_MODE=${SERVER_MODE:-release} + - RUN_MODE=${RUN_MODE:-standard} # ======================================================================= # Database Configuration (PostgreSQL) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index ccac8a77..9c5379f2 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -8,7 +8,7 @@ import type { LoginRequest, RegisterRequest, AuthResponse, - User, + CurrentUserResponse, SendVerifyCodeRequest, SendVerifyCodeResponse, PublicSettings @@ -70,9 +70,8 @@ export async function register(userData: RegisterRequest): Promise * Get current authenticated user * @returns User profile data */ -export async function getCurrentUser(): Promise { - const { data } = await apiClient.get('/auth/me') - return data +export async function getCurrentUser() { + return apiClient.get('/auth/me') } /** diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 1e0b4afe..1b247f18 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -964,8 +964,13 @@ - - + + @@ -1076,6 +1081,7 @@ import { ref, reactive, computed, watch } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' +import { useAuthStore } from '@/stores/auth' import { adminAPI } from '@/api/admin' import { useAccountOAuth, @@ -1102,6 +1108,7 @@ interface OAuthFlowExposed { } const { t } = useI18n() +const authStore = useAuthStore() const oauthStepTitle = computed(() => { if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title') diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 75ce204d..3e81ac9a 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -466,8 +466,13 @@