feat(全栈): 实现简易模式核心功能
**功能概述**: 实现简易模式(Simple Mode),为个人用户和小团队提供简化的使用体验,隐藏复杂的分组、订阅、配额等概念。 **后端改动**: 1. 配置系统 - 新增 run_mode 配置项(standard/simple) - 支持环境变量 RUN_MODE - 默认值为 standard 2. 数据库初始化 - 自动创建3个默认分组:anthropic-default、openai-default、gemini-default - 默认分组配置:无并发限制、active状态、非独占 - 幂等性保证:重复启动不会重复创建 3. 账号管理 - 创建账号时自动绑定对应平台的默认分组 - 如果未指定分组,自动查找并绑定默认分组 **前端改动**: 1. 状态管理 - authStore 新增 isSimpleMode 计算属性 - 从后端API获取并同步运行模式 2. UI隐藏 - 侧边栏:隐藏分组管理、订阅管理、兑换码菜单 - 账号管理页面:隐藏分组列 - 创建/编辑账号对话框:隐藏分组选择器 3. 路由守卫 - 限制访问分组、订阅、兑换码相关页面 - 访问受限页面时自动重定向到仪表板 **配置示例**: ```yaml run_mode: simple run_mode: standard ``` **影响范围**: - 后端:配置、数据库迁移、账号服务 - 前端:认证状态、路由、UI组件 - 部署:配置文件示例 **兼容性**: - 简易模式和标准模式可无缝切换 - 不需要数据迁移 - 现有数据不受影响
This commit is contained in:
@@ -107,6 +107,14 @@ func runSetupServer() {
|
||||
}
|
||||
|
||||
func runMainServer() {
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED")
|
||||
}
|
||||
|
||||
buildInfo := handler.BuildInfo{
|
||||
Version: Version,
|
||||
BuildType: BuildType,
|
||||
|
||||
@@ -49,7 +49,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||
userService := service.NewUserService(userRepository)
|
||||
authHandler := handler.NewAuthHandler(authService, userService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService)
|
||||
userHandler := handler.NewUserHandler(userService)
|
||||
apiKeyRepository := repository.NewApiKeyRepository(db)
|
||||
groupRepository := repository.NewGroupRepository(db)
|
||||
@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
|
||||
billingCache := repository.NewBillingCache(client)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository)
|
||||
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
|
||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
|
||||
redeemCache := repository.NewRedeemCache(client)
|
||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService)
|
||||
@@ -128,7 +128,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
|
||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService)
|
||||
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
|
||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
|
||||
@@ -7,6 +7,11 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
RunModeStandard = "standard"
|
||||
RunModeSimple = "simple"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
@@ -17,6 +22,7 @@ type Config struct {
|
||||
Pricing PricingConfig `mapstructure:"pricing"`
|
||||
Gateway GatewayConfig `mapstructure:"gateway"`
|
||||
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
|
||||
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
|
||||
Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC"
|
||||
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||
}
|
||||
@@ -135,6 +141,16 @@ type RateLimitConfig struct {
|
||||
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟)
|
||||
}
|
||||
|
||||
func NormalizeRunMode(value string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(value))
|
||||
switch normalized {
|
||||
case RunModeStandard, RunModeSimple:
|
||||
return normalized
|
||||
default:
|
||||
return RunModeStandard
|
||||
}
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
@@ -161,6 +177,8 @@ func Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("unmarshal config error: %w", err)
|
||||
}
|
||||
|
||||
cfg.RunMode = NormalizeRunMode(cfg.RunMode)
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validate config error: %w", err)
|
||||
}
|
||||
@@ -169,6 +187,8 @@ func Load() (*Config, error) {
|
||||
}
|
||||
|
||||
func setDefaults() {
|
||||
viper.SetDefault("run_mode", RunModeStandard)
|
||||
|
||||
// Server
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("server.port", 8080)
|
||||
|
||||
23
backend/internal/config/config_test.go
Normal file
23
backend/internal/config/config_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeRunMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "simple"},
|
||||
{"SIMPLE", "simple"},
|
||||
{"standard", "standard"},
|
||||
{"invalid", "standard"},
|
||||
{"", "standard"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := NormalizeRunMode(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NormalizeRunMode(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -11,13 +12,15 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
}
|
||||
@@ -157,5 +160,15 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
type UserResponse struct {
|
||||
*dto.User
|
||||
RunMode string `json:"run_mode"`
|
||||
}
|
||||
|
||||
runMode := config.RunModeStandard
|
||||
if h.cfg != nil {
|
||||
runMode = h.cfg.RunMode
|
||||
}
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
|
||||
@@ -30,6 +30,11 @@ func AutoMigrate(db *gorm.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建默认分组(简易模式支持)
|
||||
if err := ensureDefaultGroups(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 修复无效的过期时间(年份超过 2099 会导致 JSON 序列化失败)
|
||||
return fixInvalidExpiresAt(db)
|
||||
}
|
||||
@@ -47,3 +52,55 @@ func fixInvalidExpiresAt(db *gorm.DB) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureDefaultGroups 确保默认分组存在(简易模式支持)
|
||||
// 为每个平台创建一个默认分组,配置最大权限以确保简易模式下不受限制
|
||||
func ensureDefaultGroups(db *gorm.DB) error {
|
||||
defaultGroups := []struct {
|
||||
name string
|
||||
platform string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "anthropic-default",
|
||||
platform: "anthropic",
|
||||
description: "Default group for Anthropic accounts (Simple Mode)",
|
||||
},
|
||||
{
|
||||
name: "openai-default",
|
||||
platform: "openai",
|
||||
description: "Default group for OpenAI accounts (Simple Mode)",
|
||||
},
|
||||
{
|
||||
name: "gemini-default",
|
||||
platform: "gemini",
|
||||
description: "Default group for Gemini accounts (Simple Mode)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, dg := range defaultGroups {
|
||||
var count int64
|
||||
if err := db.Model(&groupModel{}).Where("name = ?", dg.name).Count(&count).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
group := &groupModel{
|
||||
Name: dg.name,
|
||||
Description: dg.description,
|
||||
Platform: dg.platform,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: "active",
|
||||
SubscriptionType: "standard",
|
||||
}
|
||||
if err := db.Create(group).Error; err != nil {
|
||||
log.Printf("[AutoMigrate] Failed to create default group %s: %v", dg.name, err)
|
||||
return err
|
||||
}
|
||||
log.Printf("[AutoMigrate] Created default group: %s (platform: %s)", dg.name, dg.platform)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,7 +59,8 @@ func TestAPIContracts(t *testing.T) {
|
||||
"status": "active",
|
||||
"allowed_groups": null,
|
||||
"created_at": "2025-01-02T03:04:05Z",
|
||||
"updated_at": "2025-01-02T03:04:05Z"
|
||||
"updated_at": "2025-01-02T03:04:05Z",
|
||||
"run_mode": "standard"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -371,6 +372,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
Default: config.DefaultConfig{
|
||||
ApiKeyPrefix: "sk-",
|
||||
},
|
||||
RunMode: config.RunModeStandard,
|
||||
}
|
||||
|
||||
userService := service.NewUserService(userRepo)
|
||||
@@ -382,7 +384,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
authHandler := handler.NewAuthHandler(nil, userService)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
|
||||
@@ -36,7 +36,7 @@ func ProvideRouter(
|
||||
r := gin.New()
|
||||
r.Use(middleware2.Recovery())
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -5,18 +5,19 @@ import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService))
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc {
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
@@ -85,6 +86,18 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -11,15 +12,15 @@ import (
|
||||
)
|
||||
|
||||
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil)
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc {
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
@@ -50,6 +51,18 @@ func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subs
|
||||
return
|
||||
}
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
|
||||
286
backend/internal/server/middleware/api_key_auth_test.go
Normal file
286
backend/internal/server/middleware/api_key_auth_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.ApiKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != sub.UserID || groupID != sub.GroupID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if r.getActive != nil {
|
||||
return r.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if r.updateStatus != nil {
|
||||
return r.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if r.activateWindow != nil {
|
||||
return r.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetDaily != nil {
|
||||
return r.resetDaily(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetWeekly != nil {
|
||||
return r.resetWeekly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetMonthly != nil {
|
||||
return r.resetMonthly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/routes"
|
||||
@@ -19,6 +20,7 @@ func SetupRouter(
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.Logger())
|
||||
@@ -30,7 +32,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -44,6 +46,7 @@ func registerRoutes(
|
||||
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
@@ -55,5 +58,5 @@ func registerRoutes(
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -15,6 +16,7 @@ func RegisterGatewayRoutes(
|
||||
apiKeyAuth middleware.ApiKeyAuthMiddleware,
|
||||
apiKeyService *service.ApiKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
) {
|
||||
// API网关(Claude API兼容)
|
||||
gateway := r.Group("/v1")
|
||||
@@ -30,7 +32,7 @@ func RegisterGatewayRoutes(
|
||||
|
||||
// Gemini 原生 API 兼容层(Gemini SDK/CLI 直连)
|
||||
gemini := r.Group("/v1beta")
|
||||
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService))
|
||||
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
|
||||
{
|
||||
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
|
||||
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
|
||||
|
||||
@@ -609,12 +609,30 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(input.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil {
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
if g.Name == defaultGroupName {
|
||||
groupIDs = []int64{g.ID}
|
||||
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
)
|
||||
|
||||
@@ -32,14 +33,16 @@ type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,6 +227,11 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
|
||||
@@ -313,7 +313,10 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:忽略 groupID,查询所有可用账号
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
@@ -1065,6 +1068,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -155,7 +156,10 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
// 简易模式:忽略分组限制,查询所有可用账号
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
@@ -754,6 +758,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
|
||||
Reference in New Issue
Block a user