From 675543240e2e61dfa8f3db64b0cdb27f5b04d1f9 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Fri, 9 Jan 2026 23:01:42 +0800 Subject: [PATCH] =?UTF-8?q?perf(=E7=BD=91=E5=85=B3):=20=E5=A4=8D=E7=94=A8?= =?UTF-8?q?=E5=88=86=E7=BB=84=E4=B8=8A=E4=B8=8B=E6=96=87=E5=87=8F=E5=B0=91?= =?UTF-8?q?=E7=83=AD=E8=B7=AF=E5=BE=84=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增 GetByIDLite 并在网关与 Gemini 选择流程复用上下文 group,避免 COUNT 触发 更新 API key 中间件注入 group 上下文,减少重复查库 补充 gateway/gemini 中间件与仓库层回归测试 测试: make test --- backend/internal/pkg/ctxkey/ctxkey.go | 2 + backend/internal/repository/group_repo.go | 15 +- .../repository/group_repo_integration_test.go | 36 +++++ backend/internal/server/api_contract_test.go | 4 + .../server/middleware/api_key_auth.go | 15 ++ .../server/middleware/api_key_auth_google.go | 2 + .../middleware/api_key_auth_google_test.go | 64 +++++++++ .../server/middleware/api_key_auth_test.go | 58 ++++++++ .../service/admin_service_delete_test.go | 4 + .../service/admin_service_group_test.go | 7 + .../service/gateway_multiplatform_test.go | 133 ++++++++++++++++++ backend/internal/service/gateway_service.go | 113 +++++++++------ .../service/gemini_messages_compat_service.go | 12 +- .../service/gemini_multiplatform_test.go | 84 ++++++++++- backend/internal/service/group_service.go | 1 + 15 files changed, 499 insertions(+), 51 deletions(-) diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 3add78de..bd10eae0 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -9,4 +9,6 @@ const ( ForcePlatform Key = "ctx_force_platform" // IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置 IsClaudeCodeClient Key = "ctx_is_claude_code_client" + // Group 认证后的分组信息,由 API Key 认证中间件设置 + Group Key = "ctx_group" ) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 1fb4ae90..daff8b89 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -60,6 +60,16 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er } func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) { + out, err := r.GetByIDLite(ctx, id) + if err != nil { + return nil, err + } + count, _ := r.GetAccountCount(ctx, out.ID) + out.AccountCount = count + return out, nil +} + +func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) { m, err := r.client.Group.Query(). Where(group.IDEQ(id)). Only(ctx) @@ -67,10 +77,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) } - out := groupEntityToService(m) - count, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = count - return out, nil + return groupEntityToService(m), nil } func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index b9079d7a..204dd0d3 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -4,6 +4,8 @@ package repository import ( "context" + "database/sql" + "errors" "testing" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -19,6 +21,20 @@ type GroupRepoSuite struct { repo *groupRepository } +type forbidSQLExecutor struct { + called bool +} + +func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + s.called = true + return nil, errors.New("unexpected sql exec") +} + +func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + s.called = true + return nil, errors.New("unexpected sql query") +} + func (s *GroupRepoSuite) SetupTest() { s.ctx = context.Background() tx := testEntTx(s.T()) @@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() { s.Require().ErrorIs(err, service.ErrGroupNotFound) } +func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() { + group := &service.Group{ + Name: "lite-group", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, group)) + + spy := &forbidSQLExecutor{} + repo := newGroupRepositoryWithSQL(s.tx.Client(), spy) + + got, err := repo.GetByIDLite(s.ctx, group.ID) + s.Require().NoError(err) + s.Require().Equal(group.ID, got.ID) + s.Require().False(spy.called, "expected no direct sql executor usage") +} + func (s *GroupRepoSuite) TestUpdate() { group := &service.Group{ Name: "original", diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 502d74b3..f9b31193 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -567,6 +567,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err return nil, service.ErrGroupNotFound } +func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) { + return nil, service.ErrGroupNotFound +} + func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error { return errors.New("not implemented") } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 74ff8af3..8d78e32d 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -1,11 +1,13 @@ package middleware import ( + "context" "errors" "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -91,6 +93,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti Concurrency: apiKey.User.Concurrency, }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) + setGroupContext(c, apiKey.Group) c.Next() return } @@ -149,6 +152,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti Concurrency: apiKey.User.Concurrency, }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) + setGroupContext(c, apiKey.Group) c.Next() } @@ -173,3 +177,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool subscription, ok := value.(*service.UserSubscription) return subscription, ok } + +func setGroupContext(c *gin.Context, group *service.Group) { + if group == nil { + return + } + if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID { + return + } + ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group) + c.Request = c.Request.WithContext(ctx) +} diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index c5afd7ef..1a0b0dd5 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs Concurrency: apiKey.User.Concurrency, }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) + setGroupContext(c, apiKey.Group) c.Next() return } @@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs Concurrency: apiKey.User.Concurrency, }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) + setGroupContext(c, apiKey.Group) c.Next() } } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 0ed5a4a2..6d6ab0fc 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -133,6 +134,69 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) { require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status) } +func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := &service.Group{ + ID: 99, + Name: "g1", + Status: service.StatusActive, + Platform: service.PlatformGemini, + } + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyService := service.NewAPIKeyService( + fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + }, + nil, + nil, + nil, + nil, + &config.Config{RunMode: config.RunModeSimple}, + ) + + cfg := &config.Config{RunMode: config.RunModeSimple} + r := gin.New() + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { + groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) + if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID { + c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} + func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index d50fb7b2..719b5d99 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -110,6 +111,63 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { }) } +func TestAPIKeyAuthSetsGroupContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + group := &service.Group{ + ID: 101, + Name: "g1", + Status: service.StatusActive, + Platform: service.PlatformAnthropic, + } + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + router := gin.New() + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group) + if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID { + c.JSON(http.StatusInternalServerError, gin.H{"ok": false}) + return + } + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index c1d2e4c9..20a65ec2 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) { panic("unexpected GetByID call") } +func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) { + panic("unexpected GetByIDLite call") +} + func (s *groupRepoStub) Update(ctx context.Context, group *Group) error { panic("unexpected Update call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 3171de11..675f4c6f 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -35,6 +35,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err return s.getByID, nil } +func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) { + if s.getErr != nil { + return nil, s.getErr + } + return s.getByID, nil +} + func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error { panic("unexpected Delete call") } diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 66c40e25..4f6545e2 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" ) @@ -185,6 +186,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro return nil } +type mockGroupRepoForGateway struct { + groups map[int64]*Group + getByIDCalls int + getByIDLiteCalls int +} + +func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) { + m.getByIDCalls++ + if g, ok := m.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) { + m.getByIDLiteCalls++ + if g, ok := m.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil } +func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil } +func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil } +func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + return nil, nil +} +func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) { + return nil, nil +} +func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) { + return nil, nil +} +func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { + return false, nil +} +func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} +func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} + func ptr[T any](v T) *T { return &v } @@ -1013,3 +1064,85 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) { require.Contains(t, err.Error(), "no available accounts") }) } + +func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) { + ctx := context.Background() + groupID := int64(42) + group := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + } + ctx = context.WithValue(ctx, ctxkey.Group, group) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{groupID: group}, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cfg: testConfig(), + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 0, groupRepo.getByIDLiteCalls) +} + +func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + fallbackID := int64(11) + group := &Group{ + ID: groupID, + Platform: PlatformAnthropic, + Status: StatusActive, + ClaudeCodeOnly: true, + FallbackGroupID: &fallbackID, + } + fallbackGroup := &Group{ + ID: fallbackID, + Platform: PlatformAnthropic, + Status: StatusActive, + } + ctx = context.WithValue(ctx, ctxkey.Group, group) + + repo := &mockAccountRepoForPlatform{ + accounts: []Account{ + {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + groupRepo := &mockGroupRepoForGateway{ + groups: map[int64]*Group{fallbackID: fallbackGroup}, + } + + svc := &GatewayService{ + accountRepo: repo, + groupRepo: groupRepo, + cfg: testConfig(), + } + + account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil) + require.NoError(t, err) + require.NotNil(t, account) + require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDLiteCalls) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e73e9406..27353022 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context if hasForcePlatform && forcePlatform != "" { platform = forcePlatform } else if groupID != nil { - // 根据分组 platform 决定查询哪种账号 - group, err := s.groupRepo.GetByID(ctx, *groupID) + group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID) if err != nil { - return nil, fmt.Errorf("get group failed: %w", err) + return nil, err } + groupID = resolvedGroupID + ctx = s.withGroupContext(ctx, group) platform = group.Platform - - // 检查 Claude Code 客户端限制 - if group.ClaudeCodeOnly { - isClaudeCode := IsClaudeCodeClient(ctx) - if !isClaudeCode { - // 非 Claude Code 客户端,检查是否有降级分组 - if group.FallbackGroupID != nil { - // 使用降级分组重新调度 - fallbackGroupID := *group.FallbackGroupID - return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs) - } - // 无降级分组,拒绝访问 - return nil, ErrClaudeCodeOnly - } - } } else { // 无分组时只使用原生 anthropic 平台 platform = PlatformAnthropic @@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) - groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) + group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) if err != nil { return nil, err } + ctx = s.withGroupContext(ctx, group) if s.concurrencyService == nil || !cfg.LoadBatchEnabled { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) @@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro }, nil } - platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID) + platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) if err != nil { return nil, err } @@ -652,51 +639,91 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { } } +func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context { + if group == nil { + return ctx + } + if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID { + return ctx + } + return context.WithValue(ctx, ctxkey.Group, group) +} + +func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group { + if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil && group.ID == groupID { + return group + } + return nil +} + +func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + if group := s.groupFromContext(ctx, groupID); group != nil { + return group, nil + } + group, err := s.groupRepo.GetByIDLite(ctx, groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) + } + return group, nil +} + +func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) { + if groupID == nil { + return nil, nil, nil + } + + currentID := *groupID + for { + group, err := s.resolveGroupByID(ctx, currentID) + if err != nil { + return nil, nil, err + } + + if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) { + return group, ¤tID, nil + } + + if group.FallbackGroupID == nil { + return nil, nil, ErrClaudeCodeOnly + } + currentID = *group.FallbackGroupID + } +} + // checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 // 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: // - 有降级分组:返回降级分组的 ID // - 无降级分组:返回 ErrClaudeCodeOnly 错误 -func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) { +func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) { if groupID == nil { - return groupID, nil + return nil, groupID, nil } // 强制平台模式不检查 Claude Code 限制 if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { - return groupID, nil + return nil, groupID, nil } - group, err := s.groupRepo.GetByID(ctx, *groupID) + group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID) if err != nil { - return nil, fmt.Errorf("get group failed: %w", err) + return nil, nil, err } - if !group.ClaudeCodeOnly { - return groupID, nil - } - - // 分组启用了 Claude Code 限制 - if IsClaudeCodeClient(ctx) { - return groupID, nil - } - - // 非 Claude Code 客户端,检查降级分组 - if group.FallbackGroupID != nil { - return group.FallbackGroupID, nil - } - - return nil, ErrClaudeCodeOnly + return group, resolvedID, nil } -func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { +func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) { forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { return forcePlatform, true, nil } + if group != nil { + return group.Platform, false, nil + } if groupID != nil { - group, err := s.groupRepo.GetByID(ctx, *groupID) + group, err := s.resolveGroupByID(ctx, *groupID) if err != nil { - return "", false, fmt.Errorf("get group failed: %w", err) + return "", false, err } return group.Platform, false, nil } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index f2b5bafd..fba50b62 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co platform = forcePlatform } else if groupID != nil { // 根据分组 platform 决定查询哪种账号 - group, err := s.groupRepo.GetByID(ctx, *groupID) - if err != nil { - return nil, fmt.Errorf("get group failed: %w", err) + var group *Group + if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && ctxGroup != nil && ctxGroup.ID == *groupID { + group = ctxGroup + } else { + var err error + group, err = s.groupRepo.GetByIDLite(ctx, *groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) + } } platform = group.Platform } else { diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 6007bce8..897f3129 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/stretchr/testify/require" ) @@ -146,10 +147,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil) // mockGroupRepoForGemini Gemini 测试用的 group repo mock type mockGroupRepoForGemini struct { - groups map[int64]*Group + groups map[int64]*Group + getByIDCalls int + getByIDLiteCalls int } func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) { + m.getByIDCalls++ + if g, ok := m.groups[id]; ok { + return g, nil + } + return nil, errors.New("group not found") +} + +func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) { + m.getByIDLiteCalls++ if g, ok := m.groups[id]; ok { return g, nil } @@ -242,6 +254,76 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户") } +func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) { + ctx := context.Background() + groupID := int64(7) + group := &Group{ + ID: groupID, + Platform: PlatformGemini, + Status: StatusActive, + } + ctx = context.WithValue(ctx, ctxkey.Group, group) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}} + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 0, groupRepo.getByIDLiteCalls) +} + +func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) { + ctx := context.Background() + groupID := int64(7) + + repo := &mockAccountRepoForGemini{ + accounts: []Account{ + {ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, + }, + accountsByID: map[int64]*Account{}, + } + for i := range repo.accounts { + repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i] + } + + cache := &mockGatewayCacheForGemini{} + groupRepo := &mockGroupRepoForGemini{ + groups: map[int64]*Group{ + groupID: {ID: groupID, Platform: PlatformGemini}, + }, + } + + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + groupRepo: groupRepo, + cache: cache, + } + + acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, 0, groupRepo.getByIDCalls) + require.Equal(t, 1, groupRepo.getByIDLiteCalls) +} + // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) { ctx := context.Background() diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 403636e8..63c69f2a 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -16,6 +16,7 @@ var ( type GroupRepository interface { Create(ctx context.Context, group *Group) error GetByID(ctx context.Context, id int64) (*Group, error) + GetByIDLite(ctx context.Context, id int64) (*Group, error) Update(ctx context.Context, group *Group) error Delete(ctx context.Context, id int64) error DeleteCascade(ctx context.Context, id int64) ([]int64, error)