diff --git a/README_CN.md b/README_CN.md index a93fb9d8..db7de488 100644 --- a/README_CN.md +++ b/README_CN.md @@ -283,6 +283,16 @@ npm run dev --- +## 简易模式 + +简易模式适合个人开发者或内部团队快速使用,不依赖完整 SaaS 功能。 + +- 启用方式:设置环境变量 `RUN_MODE=simple` +- 功能差异:隐藏 SaaS 相关功能,跳过计费流程 +- 安全注意事项:生产环境需同时设置 `SIMPLE_MODE_CONFIRM=true` 才允许启动 + +--- + ## 项目结构 ``` diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index a81a572e..6b87eb73 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -107,6 +107,14 @@ func runSetupServer() { } func runMainServer() { + cfg, err := config.Load() + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + if cfg.RunMode == config.RunModeSimple { + log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED") + } + buildInfo := handler.BuildInfo{ Version: Version, BuildType: BuildType, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 8b22c3b4..8dda96f9 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(authService, userService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewApiKeyRepository(db) groupRepository := repository.NewGroupRepository(db) @@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(db) billingCache := repository.NewBillingCache(client) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository) + billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(client) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService) @@ -132,7 +132,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) - apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService) + apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) httpServer := server.ProvideHTTPServer(configConfig, engine) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 485ed42d..5ce1222f 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -7,6 +7,11 @@ import ( "github.com/spf13/viper" ) +const ( + RunModeStandard = "standard" + RunModeSimple = "simple" +) + type Config struct { Server ServerConfig `mapstructure:"server"` Database DatabaseConfig `mapstructure:"database"` @@ -17,6 +22,7 @@ type Config struct { Pricing PricingConfig `mapstructure:"pricing"` Gateway GatewayConfig `mapstructure:"gateway"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" Gemini GeminiConfig `mapstructure:"gemini"` } @@ -135,6 +141,16 @@ type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } +func NormalizeRunMode(value string) string { + normalized := strings.ToLower(strings.TrimSpace(value)) + switch normalized { + case RunModeStandard, RunModeSimple: + return normalized + default: + return RunModeStandard + } +} + func Load() (*Config, error) { viper.SetConfigName("config") viper.SetConfigType("yaml") @@ -161,6 +177,8 @@ func Load() (*Config, error) { return nil, fmt.Errorf("unmarshal config error: %w", err) } + cfg.RunMode = NormalizeRunMode(cfg.RunMode) + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } @@ -169,6 +187,8 @@ func Load() (*Config, error) { } func setDefaults() { + viper.SetDefault("run_mode", RunModeStandard) + // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go new file mode 100644 index 00000000..1f1becb8 --- /dev/null +++ b/backend/internal/config/config_test.go @@ -0,0 +1,23 @@ +package config + +import "testing" + +func TestNormalizeRunMode(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"simple", "simple"}, + {"SIMPLE", "simple"}, + {"standard", "standard"}, + {"invalid", "standard"}, + {"", "standard"}, + } + + for _, tt := range tests { + result := NormalizeRunMode(tt.input) + if result != tt.expected { + t.Errorf("NormalizeRunMode(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 799d63d8..8466f131 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,6 +1,7 @@ package handler import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -11,13 +12,15 @@ import ( // AuthHandler handles authentication-related requests type AuthHandler struct { + cfg *config.Config authService *service.AuthService userService *service.UserService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { return &AuthHandler{ + cfg: cfg, authService: authService, userService: userService, } @@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) { return } - response.Success(c, dto.UserFromService(user)) + type UserResponse struct { + *dto.User + RunMode string `json:"run_mode"` + } + + runMode := config.RunModeStandard + if h.cfg != nil { + runMode = h.cfg.RunMode + } + + response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode}) } diff --git a/backend/internal/repository/auto_migrate.go b/backend/internal/repository/auto_migrate.go index 9127eeb9..f76e3719 100644 --- a/backend/internal/repository/auto_migrate.go +++ b/backend/internal/repository/auto_migrate.go @@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error { return err } + // 创建默认分组(简易模式支持) + if err := ensureDefaultGroups(db); err != nil { + return err + } + // 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败) return fixInvalidExpiresAt(db) } @@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error { } return nil } + +// ensureDefaultGroups 确保默认分组存在(简易模式支持) +// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制 +func ensureDefaultGroups(db *gorm.DB) error { + defaultGroups := []struct { + name string + platform string + description string + }{ + { + name: "anthropic-default", + platform: "anthropic", + description: "Default group for Anthropic accounts (Simple Mode)", + }, + { + name: "openai-default", + platform: "openai", + description: "Default group for OpenAI accounts (Simple Mode)", + }, + { + name: "gemini-default", + platform: "gemini", + description: "Default group for Gemini accounts (Simple Mode)", + }, + } + + for _, dg := range defaultGroups { + var count int64 + if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&count).Error; err != nil { + return err + } + + if count == 0 { + group := &groupModel{ + Name: dg.name, + Description: dg.description, + Platform: dg.platform, + RateMultiplier: 1.0, + IsExclusive: false, + Status: "active", + SubscriptionType: "standard", + } + if err := db.Create(group).Error; err != nil { + log.Printf("[AutoMigrate] Failed to create default group %s: %v", dg.name, err) + return err + } + log.Printf("[AutoMigrate] Created default group: %s (platform: %s)", dg.name, dg.platform) + } + } + + return nil +} diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index 33ff6326..42cf2194 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -82,8 +82,9 @@ func (s *GroupRepoSuite) TestList() { groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) s.Require().NoError(err, "List") - s.Require().Len(groups, 2) - s.Require().Equal(int64(2), page.Total) + // 3 default groups + 2 test groups = 5 total + s.Require().Len(groups, 5) + s.Require().Equal(int64(5), page.Total) } func (s *GroupRepoSuite) TestListWithFilters_Platform() { @@ -92,8 +93,12 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) s.Require().NoError(err) - s.Require().Len(groups, 1) - s.Require().Equal(service.PlatformOpenAI, groups[0].Platform) + // 1 default openai group + 1 test openai group = 2 total + s.Require().Len(groups, 2) + // Verify all groups are OpenAI platform + for _, g := range groups { + s.Require().Equal(service.PlatformOpenAI, g.Platform) + } } func (s *GroupRepoSuite) TestListWithFilters_Status() { @@ -151,8 +156,17 @@ func (s *GroupRepoSuite) TestListActive() { groups, err := s.repo.ListActive(s.ctx) s.Require().NoError(err, "ListActive") - s.Require().Len(groups, 1) - s.Require().Equal("active1", groups[0].Name) + // 3 default groups (all active) + 1 test active group = 4 total + s.Require().Len(groups, 4) + // Verify our test group is in the results + var found bool + for _, g := range groups { + if g.Name == "active1" { + found = true + break + } + } + s.Require().True(found, "active1 group should be in results") } func (s *GroupRepoSuite) TestListActiveByPlatform() { @@ -162,8 +176,17 @@ func (s *GroupRepoSuite) TestListActiveByPlatform() { groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic) s.Require().NoError(err, "ListActiveByPlatform") - s.Require().Len(groups, 1) - s.Require().Equal("g1", groups[0].Name) + // 1 default anthropic group + 1 test active anthropic group = 2 total + s.Require().Len(groups, 2) + // Verify our test group is in the results + var found bool + for _, g := range groups { + if g.Name == "g1" { + found = true + break + } + } + s.Require().True(found, "g1 group should be in results") } // --- ExistsByName --- diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 06eb2ebf..31e62861 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) { "status": "active", "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", - "updated_at": "2025-01-02T03:04:05Z" + "updated_at": "2025-01-02T03:04:05Z", + "run_mode": "standard" } }`, }, @@ -369,6 +370,7 @@ func newContractDeps(t *testing.T) *contractDeps { Default: config.DefaultConfig{ ApiKeyPrefix: "sk-", }, + RunMode: config.RunModeStandard, } userService := service.NewUserService(userRepo) @@ -380,7 +382,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - authHandler := handler.NewAuthHandler(nil, userService) + authHandler := handler.NewAuthHandler(cfg, nil, userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil) diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 88833d63..b64220d9 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -36,7 +36,7 @@ func ProvideRouter( r := gin.New() r.Use(middleware2.Recovery()) - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) } // ProvideHTTPServer 提供 HTTP 服务器 diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index c4620d91..75e508dd 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -5,18 +5,19 @@ import ( "log" "strings" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) // NewApiKeyAuthMiddleware 创建 API Key 认证中间件 -func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware { - return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService)) +func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware { + return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)) } // apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证) -func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc { +func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { // 尝试从Authorization header中提取API key (Bearer scheme) authHeader := c.GetHeader("Authorization") @@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti return } + if cfg.RunMode == config.RunModeSimple { + // 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文 + c.Set(string(ContextKeyApiKey), apiKey) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: apiKey.User.ID, + Concurrency: apiKey.User.Concurrency, + }) + c.Set(string(ContextKeyUserRole), apiKey.User.Role) + c.Next() + return + } + // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 199aca82..d8f47bd2 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -4,6 +4,7 @@ import ( "errors" "strings" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/service" @@ -11,15 +12,15 @@ import ( ) // ApiKeyAuthGoogle is a Google-style error wrapper for API key auth. -func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService) gin.HandlerFunc { - return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil) +func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc { + return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg) } // ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors: // {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}} // // It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations. -func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc { +func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { apiKeyString := extractAPIKeyFromRequest(c) if apiKeyString == "" { @@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs return } + // 简易模式:跳过余额和订阅检查 + if cfg.RunMode == config.RunModeSimple { + c.Set(string(ContextKeyApiKey), apiKey) + c.Set(string(ContextKeyUser), AuthSubject{ + UserID: apiKey.User.ID, + Concurrency: apiKey.User.Concurrency, + }) + c.Set(string(ContextKeyUserRole), apiKey.User.Role) + c.Next() + return + } + isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { subscription, err := subscriptionService.GetActiveSubscription( diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go new file mode 100644 index 00000000..a9d22ede --- /dev/null +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -0,0 +1,286 @@ +//go:build unit + +package middleware + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestSimpleModeBypassesQuotaCheck(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := 1.0 + group := &service.Group{ + ID: 42, + Name: "sub", + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeSubscription, + DailyLimitUSD: &limit, + } + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.ApiKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) { + if key != apiKey.Key { + return nil, service.ErrApiKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + + now := time.Now() + sub := &service.UserSubscription{ + ID: 55, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: now.Add(24 * time.Hour), + DailyWindowStart: &now, + DailyUsageUSD: 10, + } + subscriptionRepo := &stubUserSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if userID != sub.UserID || groupID != sub.GroupID { + return nil, service.ErrSubscriptionNotFound + } + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + } + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code) + require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED") + }) +} + +func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { + router := gin.New() + router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + return router +} + +type stubApiKeyRepo struct { + getByKey func(ctx context.Context, key string) (*service.ApiKey, error) +} + +func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) { + if r.getByKey != nil { + return r.getByKey(ctx, key) + } + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) { + return false, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubUserSubscriptionRepo struct { + getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) + updateStatus func(ctx context.Context, subscriptionID int64, status string) error + activateWindow func(ctx context.Context, id int64, start time.Time) error + resetDaily func(ctx context.Context, id int64, start time.Time) error + resetWeekly func(ctx context.Context, id int64, start time.Time) error + resetMonthly func(ctx context.Context, id int64, start time.Time) error +} + +func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if r.getActive != nil { + return r.getActive(ctx, userID, groupID) + } + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + if r.updateStatus != nil { + return r.updateStatus(ctx, subscriptionID, status) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + if r.activateWindow != nil { + return r.activateWindow(ctx, id, start) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + if r.resetDaily != nil { + return r.resetDaily(ctx, id, newWindowStart) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + if r.resetWeekly != nil { + return r.resetWeekly(ctx, id, newWindowStart) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error { + if r.resetMonthly != nil { + return r.resetMonthly(ctx, id, newWindowStart) + } + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} + +func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index 5489468b..2371dafb 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -1,6 +1,7 @@ package server import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/routes" @@ -19,6 +20,7 @@ func SetupRouter( apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, + cfg *config.Config, ) *gin.Engine { // 应用中间件 r.Use(middleware2.Logger()) @@ -30,7 +32,7 @@ func SetupRouter( } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg) return r } @@ -44,6 +46,7 @@ func registerRoutes( apiKeyAuth middleware2.ApiKeyAuthMiddleware, apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, + cfg *config.Config, ) { // 通用路由(健康检查、状态等) routes.RegisterCommonRoutes(r) @@ -55,5 +58,5 @@ func registerRoutes( routes.RegisterAuthRoutes(v1, h, jwtAuth) routes.RegisterUserRoutes(v1, h, jwtAuth) routes.RegisterAdminRoutes(v1, h, adminAuth) - routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService) + routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg) } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 2bf388f8..34792be8 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,7 @@ package routes import ( + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -15,6 +16,7 @@ func RegisterGatewayRoutes( apiKeyAuth middleware.ApiKeyAuthMiddleware, apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, + cfg *config.Config, ) { // API网关(Claude API兼容) gateway := r.Group("/v1") @@ -30,7 +32,7 @@ func RegisterGatewayRoutes( // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) gemini := r.Group("/v1beta") - gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService)) + gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { gemini.GET("/models", h.Gateway.GeminiV1BetaListModels) gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) @@ -54,7 +56,7 @@ func RegisterGatewayRoutes( antigravityV1Beta := r.Group("/antigravity/v1beta") antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity)) - antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService)) + antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg)) { antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels) antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 575e72b1..94d4c747 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -54,15 +54,23 @@ type UsageLogRepository interface { GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) } -// usageCache 用于缓存usage数据 -type usageCache struct { - data *UsageInfo +// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) +type apiUsageCache struct { + response *ClaudeUsageResponse + timestamp time.Time +} + +// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost) +type windowStatsCache struct { + stats *WindowStats timestamp time.Time } var ( - usageCacheMap = sync.Map{} - cacheTTL = 10 * time.Minute + apiCacheMap = sync.Map{} // 缓存 API 响应 + windowStatsCacheMap = sync.Map{} // 缓存窗口统计 + apiCacheTTL = 10 * time.Minute + windowStatsCacheTTL = 1 * time.Minute ) // WindowStats 窗口期统计 @@ -126,7 +134,7 @@ func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLog } // GetUsage 获取账号使用量 -// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),缓存10分钟 +// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟 // Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) // API Key账号: 不支持usage查询 func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { @@ -137,30 +145,34 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U // 只有oauth类型账号可以通过API获取usage(有profile scope) if account.CanGetUsage() { - // 检查缓存 - if cached, ok := usageCacheMap.Load(accountID); ok { - cache, ok := cached.(*usageCache) - if !ok { - usageCacheMap.Delete(accountID) - } else if time.Since(cache.timestamp) < cacheTTL { - return cache.data, nil + var apiResp *ClaudeUsageResponse + + // 1. 检查 API 缓存(10 分钟) + if cached, ok := apiCacheMap.Load(accountID); ok { + if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL { + apiResp = cache.response } } - // 从API获取数据 - usage, err := s.fetchOAuthUsage(ctx, account) - if err != nil { - return nil, err + // 2. 如果没有缓存,从 API 获取 + if apiResp == nil { + apiResp, err = s.fetchOAuthUsageRaw(ctx, account) + if err != nil { + return nil, err + } + // 缓存 API 响应 + apiCacheMap.Store(accountID, &apiUsageCache{ + response: apiResp, + timestamp: time.Now(), + }) } - // 添加5h窗口统计数据 - s.addWindowStats(ctx, account, usage) + // 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds) + now := time.Now() + usage := s.buildUsageInfo(apiResp, &now) - // 缓存结果 - usageCacheMap.Store(accountID, &usageCache{ - data: usage, - timestamp: time.Now(), - }) + // 4. 添加窗口统计(有独立缓存,1 分钟) + s.addWindowStats(ctx, account, usage) return usage, nil } @@ -177,31 +189,54 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("account type %s does not support usage query", account.Type) } -// addWindowStats 为usage数据添加窗口期统计 +// addWindowStats 为 usage 数据添加窗口期统计 +// 使用独立缓存(1 分钟),与 API 缓存分离 func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { - if usage.FiveHour == nil { + // 修复:即使 FiveHour 为 nil,也要尝试获取统计数据 + // 因为 SevenDay/SevenDaySonnet 可能需要 + if usage.FiveHour == nil && usage.SevenDay == nil && usage.SevenDaySonnet == nil { return } - // 使用session_window_start作为统计起始时间 - var startTime time.Time - if account.SessionWindowStart != nil { - startTime = *account.SessionWindowStart - } else { - // 如果没有窗口信息,使用5小时前作为默认 - startTime = time.Now().Add(-5 * time.Hour) + // 检查窗口统计缓存(1 分钟) + var windowStats *WindowStats + if cached, ok := windowStatsCacheMap.Load(account.ID); ok { + if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL { + windowStats = cache.stats + } } - stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) - if err != nil { - log.Printf("Failed to get window stats for account %d: %v", account.ID, err) - return + // 如果没有缓存,从数据库查询 + if windowStats == nil { + var startTime time.Time + if account.SessionWindowStart != nil { + startTime = *account.SessionWindowStart + } else { + startTime = time.Now().Add(-5 * time.Hour) + } + + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) + if err != nil { + log.Printf("Failed to get window stats for account %d: %v", account.ID, err) + return + } + + windowStats = &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + } + + // 缓存窗口统计(1 分钟) + windowStatsCacheMap.Store(account.ID, &windowStatsCache{ + stats: windowStats, + timestamp: time.Now(), + }) } - usage.FiveHour.WindowStats = &WindowStats{ - Requests: stats.Requests, - Tokens: stats.Tokens, - Cost: stats.Cost, + // 为 FiveHour 添加 WindowStats(5h 窗口统计) + if usage.FiveHour != nil { + usage.FiveHour.WindowStats = windowStats } } @@ -227,8 +262,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI return stats, nil } -// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 -func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) { +// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo) +func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) { accessToken := account.GetCredential("access_token") if accessToken == "" { return nil, fmt.Errorf("no access token available") @@ -239,13 +274,7 @@ func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Acco proxyURL = account.Proxy.URL() } - usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) - if err != nil { - return nil, err - } - - now := time.Now() - return s.buildUsageInfo(usageResp, &now), nil + return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) } // parseTime 尝试多种格式解析时间 @@ -270,20 +299,16 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA UpdatedAt: updatedAt, } - // 5小时窗口 + // 5小时窗口 - 始终创建对象(即使 ResetsAt 为空) + info.FiveHour = &UsageProgress{ + Utilization: resp.FiveHour.Utilization, + } if resp.FiveHour.ResetsAt != "" { if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil { - info.FiveHour = &UsageProgress{ - Utilization: resp.FiveHour.Utilization, - ResetsAt: &fiveHourReset, - RemainingSeconds: int(time.Until(fiveHourReset).Seconds()), - } + info.FiveHour.ResetsAt = &fiveHourReset + info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds()) } else { log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err) - // 即使解析失败也返回utilization - info.FiveHour = &UsageProgress{ - Utilization: resp.FiveHour.Utilization, - } } } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index db207ce5..9ffd342d 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err } + // 绑定分组 - if len(input.GroupIDs) > 0 { - if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil { + groupIDs := input.GroupIDs + // 如果没有指定分组,自动绑定对应平台的默认分组 + if len(groupIDs) == 0 { + defaultGroupName := input.Platform + "-default" + groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform) + if err == nil { + for _, g := range groups { + if g.Name == defaultGroupName { + groupIDs = []int64{g.ID} + log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID) + break + } + } + } + } + + if len(groupIDs) > 0 { + if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { return nil, err } } + return account, nil } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 18f125ca..9493a11f 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -6,6 +6,7 @@ import ( "log" "time" + "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors" ) @@ -32,14 +33,16 @@ type BillingCacheService struct { cache BillingCache userRepo UserRepository subRepo UserSubscriptionRepository + cfg *config.Config } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService { +func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService { return &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, + cfg: cfg, } } @@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID // 余额模式:检查缓存余额 > 0 // 订阅模式:检查缓存用量未超过限额(Group限额从参数传入) func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error { + // 简易模式:跳过所有计费检查 + if s.cfg.RunMode == config.RunModeSimple { + return nil + } + // 判断计费模式 isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 641962ea..08e3c1d1 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -357,7 +357,10 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 2. 获取可调度账号列表(单平台) var accounts []Account var err error - if groupID != nil { + if s.cfg.RunMode == config.RunModeSimple { + // 简易模式:忽略 groupID,查询所有可用账号 + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) @@ -1226,6 +1229,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu log.Printf("Create usage log failed: %v", err) } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + // 根据计费类型执行扣费 if isSubscriptionBilling { // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 20bf57f2..79801b29 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log" "net/http" "regexp" "strconv" @@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 2. Get schedulable OpenAI accounts var accounts []Account var err error - if groupID != nil { + // 简易模式:忽略分组限制,查询所有可用账号 + if s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) @@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec _ = s.usageLogRepo.Create(ctx, usageLog) + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + s.deferredService.ScheduleLastUsedUpdate(account.ID) + return nil + } + // Deduct based on billing type if isSubscriptionBilling { if cost.TotalCost > 0 { diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 3ff47e7d..6b190cf3 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -164,6 +164,14 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl return nil } +// UpdateConcurrency 更新用户并发数(管理员功能) +func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concurrency int) error { + if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil { + return fmt.Errorf("update concurrency: %w", err) + } + return nil +} + // UpdateStatus 更新用户状态(管理员功能) func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error { user, err := s.userRepo.GetByID(ctx, userID) diff --git a/deploy/.env.example b/deploy/.env.example index de7ea722..19fcc853 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -20,6 +20,10 @@ SERVER_PORT=8080 # Server mode: release or debug SERVER_MODE=release +# 运行模式: standard (默认) 或 simple (内部自用) +# standard: 完整 SaaS 功能,包含计费/余额校验;simple: 隐藏 SaaS 功能并跳过计费/余额校验 +RUN_MODE=standard + # Timezone TZ=Asia/Shanghai diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index b6df4f65..fcaa7b7c 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -13,6 +13,14 @@ server: # Mode: "debug" for development, "release" for production mode: "release" +# ============================================================================= +# Run Mode Configuration +# ============================================================================= +# Run mode: "standard" (default) or "simple" (for internal use) +# - standard: Full SaaS features with billing/balance checks +# - simple: Hides SaaS features and skips billing/balance checks +run_mode: "standard" + # ============================================================================= # Database Configuration (PostgreSQL) # ============================================================================= diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 9e10ec54..0e3fb16e 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -36,6 +36,7 @@ services: - SERVER_HOST=0.0.0.0 - SERVER_PORT=8080 - SERVER_MODE=${SERVER_MODE:-release} + - RUN_MODE=${RUN_MODE:-standard} # ======================================================================= # Database Configuration (PostgreSQL) diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index ccac8a77..9c5379f2 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -8,7 +8,7 @@ import type { LoginRequest, RegisterRequest, AuthResponse, - User, + CurrentUserResponse, SendVerifyCodeRequest, SendVerifyCodeResponse, PublicSettings @@ -70,9 +70,8 @@ export async function register(userData: RegisterRequest): Promise * Get current authenticated user * @returns User profile data */ -export async function getCurrentUser(): Promise { - const { data } = await apiClient.get('/auth/me') - return data +export async function getCurrentUser() { + return apiClient.get('/auth/me') } /** diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 1b12df01..84cd8ed7 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -585,7 +585,7 @@ : 'https://api.anthropic.com' " /> -

{{ t('admin.accounts.baseUrlHint') }}

+

{{ baseUrlHint }}

@@ -602,13 +602,7 @@ : 'sk-ant-...' " /> -

- {{ - form.platform === 'gemini' - ? t('admin.accounts.gemini.apiKeyHint') - : t('admin.accounts.apiKeyHint') - }} -

+

{{ apiKeyHint }}

@@ -1055,8 +1049,9 @@ - + { if (form.platform === 'openai') return t('admin.accounts.oauth.openai.title') @@ -1207,6 +1204,19 @@ const oauthStepTitle = computed(() => { return t('admin.accounts.oauth.title') }) +// Platform-specific hints for API Key type +const baseUrlHint = computed(() => { + if (form.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (form.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') + return t('admin.accounts.baseUrlHint') +}) + +const apiKeyHint = computed(() => { + if (form.platform === 'openai') return t('admin.accounts.openai.apiKeyHint') + if (form.platform === 'gemini') return t('admin.accounts.gemini.apiKeyHint') + return t('admin.accounts.apiKeyHint') +}) + interface Props { show: boolean proxies: Proxy[] diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 11b01b17..f678cc6f 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -32,7 +32,7 @@ : 'https://api.anthropic.com' " /> -

{{ t('admin.accounts.baseUrlHint') }}

+

{{ baseUrlHint }}

@@ -497,8 +497,9 @@
- + { + if (!props.account) return t('admin.accounts.baseUrlHint') + if (props.account.platform === 'openai') return t('admin.accounts.openai.baseUrlHint') + if (props.account.platform === 'gemini') return t('admin.accounts.gemini.baseUrlHint') + return t('admin.accounts.baseUrlHint') +}) // Model mapping type interface ModelMapping { diff --git a/frontend/src/components/common/Select.vue b/frontend/src/components/common/Select.vue index 71a41431..725aa1f3 100644 --- a/frontend/src/components/common/Select.vue +++ b/frontend/src/components/common/Select.vue @@ -297,7 +297,7 @@ onUnmounted(() => { } .select-dropdown { - @apply absolute z-[100] mt-2 w-full; + @apply absolute left-0 z-[100] mt-2 min-w-full w-max max-w-[300px]; @apply bg-white dark:bg-dark-800; @apply rounded-xl; @apply border border-gray-200 dark:border-dark-700; @@ -339,7 +339,7 @@ onUnmounted(() => { } .select-option-label { - @apply truncate; + @apply flex-1 min-w-0 truncate text-left; } .select-empty { diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index cfbd7c14..f4c0338a 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -45,8 +45,8 @@ - -
-

- {{ t('admin.groups.subscription.title') }} -

- -
+
{ ] }) +// Helper function to format date in local timezone +const formatLocalDate = (date: Date): string => { + return `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}-${String(date.getDate()).padStart(2, '0')}` +} + +// Initialize date range immediately +const now = new Date() +const weekAgo = new Date(now) +weekAgo.setDate(weekAgo.getDate() - 6) + // Date range state -const startDate = ref('') -const endDate = ref('') +const startDate = ref(formatLocalDate(weekAgo)) +const endDate = ref(formatLocalDate(now)) const filters = ref({ user_id: undefined, @@ -752,18 +762,9 @@ const filters = ref({ end_date: undefined }) -// Initialize default date range (last 7 days) -const initializeDateRange = () => { - const now = new Date() - const today = now.toISOString().split('T')[0] - const weekAgo = new Date(now) - weekAgo.setDate(weekAgo.getDate() - 6) - - startDate.value = weekAgo.toISOString().split('T')[0] - endDate.value = today - filters.value.start_date = startDate.value - filters.value.end_date = endDate.value -} +// Initialize filters with date range +filters.value.start_date = startDate.value +filters.value.end_date = endDate.value // User search with debounce const debounceSearchUsers = () => { @@ -988,9 +989,12 @@ const loadModelOptions = async () => { const endDate = new Date() const startDateRange = new Date(endDate) startDateRange.setDate(startDateRange.getDate() - 29) + // Use local timezone instead of UTC + const endDateStr = `${endDate.getFullYear()}-${String(endDate.getMonth() + 1).padStart(2, '0')}-${String(endDate.getDate()).padStart(2, '0')}` + const startDateStr = `${startDateRange.getFullYear()}-${String(startDateRange.getMonth() + 1).padStart(2, '0')}-${String(startDateRange.getDate()).padStart(2, '0')}` const response = await adminAPI.dashboard.getModelStats({ - start_date: startDateRange.toISOString().split('T')[0], - end_date: endDate.toISOString().split('T')[0] + start_date: startDateStr, + end_date: endDateStr }) const uniqueModels = new Set() response.models?.forEach((stat) => { @@ -1022,7 +1026,13 @@ const resetFilters = () => { } granularity.value = 'day' // Reset date range to default (last 7 days) - initializeDateRange() + const now = new Date() + const weekAgo = new Date(now) + weekAgo.setDate(weekAgo.getDate() - 6) + startDate.value = formatLocalDate(weekAgo) + endDate.value = formatLocalDate(now) + filters.value.start_date = startDate.value + filters.value.end_date = endDate.value pagination.value.page = 1 loadApiKeys() loadUsageLogs() @@ -1114,7 +1124,6 @@ const hideTooltip = () => { } onMounted(() => { - initializeDateRange() loadFilterOptions() loadApiKeys() loadUsageLogs() diff --git a/frontend/src/views/user/DashboardView.vue b/frontend/src/views/user/DashboardView.vue index d660e1a0..1ef4f0d2 100644 --- a/frontend/src/views/user/DashboardView.vue +++ b/frontend/src/views/user/DashboardView.vue @@ -10,7 +10,7 @@
-
+
(null) // Recent usage const recentUsage = ref([]) +// Helper function to format date in local timezone +const formatLocalDate = (date: Date): string => { + return `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}-${String(date.getDate()).padStart(2, '0')}` +} + +// Initialize date range immediately (not in onMounted) +const now = new Date() +const weekAgo = new Date(now) +weekAgo.setDate(weekAgo.getDate() - 6) + // Date range const granularity = ref<'day' | 'hour'>('day') -const startDate = ref('') -const endDate = ref('') +const startDate = ref(formatLocalDate(weekAgo)) +const endDate = ref(formatLocalDate(now)) // Granularity options for Select component const granularityOptions = computed(() => [ @@ -963,18 +973,6 @@ const onDateRangeChange = (range: { loadChartData() } -// Initialize default date range -const initializeDateRange = () => { - const now = new Date() - const today = now.toISOString().split('T')[0] - const weekAgo = new Date(now) - weekAgo.setDate(weekAgo.getDate() - 6) - - startDate.value = weekAgo.toISOString().split('T')[0] - endDate.value = today - granularity.value = 'day' -} - // Load data const loadDashboardStats = async () => { loading.value = true @@ -1015,8 +1013,11 @@ const loadChartData = async () => { const loadRecentUsage = async () => { loadingUsage.value = true try { - const endDate = new Date().toISOString().split('T')[0] - const startDate = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000).toISOString().split('T')[0] + // Use local timezone instead of UTC + const now = new Date() + const endDate = `${now.getFullYear()}-${String(now.getMonth() + 1).padStart(2, '0')}-${String(now.getDate()).padStart(2, '0')}` + const weekAgo = new Date(Date.now() - 7 * 24 * 60 * 60 * 1000) + const startDate = `${weekAgo.getFullYear()}-${String(weekAgo.getMonth() + 1).padStart(2, '0')}-${String(weekAgo.getDate()).padStart(2, '0')}` const usageResponse = await usageAPI.getByDateRange(startDate, endDate) recentUsage.value = usageResponse.items.slice(0, 5) } catch (error) { @@ -1035,9 +1036,6 @@ onMounted(async () => { console.error('Failed to refresh subscription status:', error) }) - // Initialize date range (synchronous) - initializeDateRange() - // Load chart data and recent usage in parallel (non-critical) Promise.all([loadChartData(), loadRecentUsage()]).catch((error) => { console.error('Error loading secondary data:', error) diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index b326b4c5..53e6aa0b 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -488,9 +488,19 @@ const apiKeyOptions = computed(() => { ] }) +// Helper function to format date in local timezone +const formatLocalDate = (date: Date): string => { + return `${date.getFullYear()}-${String(date.getMonth() + 1).padStart(2, '0')}-${String(date.getDate()).padStart(2, '0')}` +} + +// Initialize date range immediately +const now = new Date() +const weekAgo = new Date(now) +weekAgo.setDate(weekAgo.getDate() - 6) + // Date range state -const startDate = ref('') -const endDate = ref('') +const startDate = ref(formatLocalDate(weekAgo)) +const endDate = ref(formatLocalDate(now)) const filters = ref({ api_key_id: undefined, @@ -498,18 +508,9 @@ const filters = ref({ end_date: undefined }) -// Initialize default date range (last 7 days) -const initializeDateRange = () => { - const now = new Date() - const today = now.toISOString().split('T')[0] - const weekAgo = new Date(now) - weekAgo.setDate(weekAgo.getDate() - 6) - - startDate.value = weekAgo.toISOString().split('T')[0] - endDate.value = today - filters.value.start_date = startDate.value - filters.value.end_date = endDate.value -} +// Initialize filters with date range +filters.value.start_date = startDate.value +filters.value.end_date = endDate.value // Handle date range change from DateRangePicker const onDateRangeChange = (range: { @@ -629,7 +630,13 @@ const resetFilters = () => { end_date: undefined } // Reset date range to default (last 7 days) - initializeDateRange() + const now = new Date() + const weekAgo = new Date(now) + weekAgo.setDate(weekAgo.getDate() - 6) + startDate.value = formatLocalDate(weekAgo) + endDate.value = formatLocalDate(now) + filters.value.start_date = startDate.value + filters.value.end_date = endDate.value pagination.page = 1 loadUsageLogs() loadUsageStats() @@ -772,7 +779,6 @@ const hideTooltip = () => { } onMounted(() => { - initializeDateRange() loadApiKeys() loadUsageLogs() loadUsageStats()