Merge branch 'test' into dev

This commit is contained in:
yangjianbo
2026-01-10 09:39:02 +08:00
18 changed files with 805 additions and 61 deletions

View File

@@ -9,4 +9,6 @@ const (
ForcePlatform Key = "ctx_force_platform"
// IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
// Group 认证后的分组信息,由 API Key 认证中间件设置
Group Key = "ctx_group"
)

View File

@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,

View File

@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
out, err := r.GetByIDLite(ctx, id)
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
// AccountCount is intentionally not loaded here; use GetByID when needed.
m, err := r.client.Group.Query().
Where(group.IDEQ(id)).
Only(ctx)
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
out := groupEntityToService(m)
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
return groupEntityToService(m), nil
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {

View File

@@ -4,6 +4,8 @@ package repository
import (
"context"
"database/sql"
"errors"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo *groupRepository
}
type forbidSQLExecutor struct {
called bool
}
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
s.called = true
return nil, errors.New("unexpected sql exec")
}
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
s.called = true
return nil, errors.New("unexpected sql query")
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
group := &service.Group{
Name: "lite-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
spy := &forbidSQLExecutor{}
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
got, err := repo.GetByIDLite(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.ID, got.ID)
s.Require().False(spy.called, "expected no direct sql executor usage")
}
func (s *GroupRepoSuite) TestUpdate() {
group := &service.Group{
Name: "original",

View File

@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
return nil, service.ErrGroupNotFound
}
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
return errors.New("not implemented")
}

View File

@@ -1,11 +1,13 @@
package middleware
import (
"context"
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
return
}
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
}
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
subscription, ok := value.(*service.UserSubscription)
return subscription, ok
}
func setGroupContext(c *gin.Context, group *service.Group) {
if !service.IsGroupContextValid(group) {
return
}
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
return
}
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
c.Request = c.Request.WithContext(ctx)
}

View File

@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
return
}
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
setGroupContext(c, apiKey.Group)
c.Next()
}
}

View File

@@ -9,6 +9,7 @@ import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 99,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformGemini,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyService := service.NewAPIKeyService(
fakeAPIKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
},
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)
cfg := &config.Config{RunMode: config.RunModeSimple}
r := gin.New()
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("x-api-key", apiKey.Key)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
ID: 42,
Name: "sub",
Status: service.StatusActive,
Hydrated: true,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
})
}
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
gin.SetMode(gin.TestMode)
group := &service.Group{
ID: 101,
Name: "g1",
Status: service.StatusActive,
Platform: service.PlatformAnthropic,
Hydrated: true,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
invalidGroup := &service.Group{
ID: group.ID,
Platform: group.Platform,
Status: group.Status,
}
router.GET("/t", func(c *gin.Context) {
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
return
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))

View File

@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
return fmt.Errorf("cannot set self as fallback group")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
visited := map[int64]struct{}{}
nextID := fallbackGroupID
for {
if _, seen := visited[nextID]; seen {
return fmt.Errorf("fallback group cycle detected")
}
visited[nextID] = struct{}{}
if currentGroupID > 0 && nextID == currentGroupID {
return fmt.Errorf("fallback group cycle detected")
}
// 降级分组不能启用 claude_code_only否则会造成死循环
if fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
return nil
// 降级分组不能启用 claude_code_only否则会造成死循环
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
if fallbackGroup.FallbackGroupID == nil {
return nil
}
nextID = *fallbackGroup.FallbackGroupID
}
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {

View File

@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call")
}
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByIDLite call")
}
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
panic("unexpected Update call")
}

View File

@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
groupID := int64(1)
fallbackID := int64(2)
repo := &groupRepoStubForFallbackCycle{
groups: map[int64]*Group{
groupID: {
ID: groupID,
FallbackGroupID: &fallbackID,
},
fallbackID: {
ID: fallbackID,
FallbackGroupID: &groupID,
},
},
}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cycle")
}
type groupRepoStubForFallbackCycle struct {
groups map[int64]*Group
}
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -191,6 +192,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
return nil
}
type mockGroupRepoForGateway struct {
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
return nil, nil
}
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
return false, nil
}
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, nil
}
func ptr[T any](v T) *T {
return &v
}
@@ -1019,3 +1070,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
require.Contains(t, err.Error(), "no available accounts")
})
}
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(42)
ctxGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
groupID := int64(42)
invalidGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
hydratedGroup := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
Hydrated: true,
}
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
svc := &GatewayService{}
ctx = svc.withGroupContext(ctx, hydratedGroup)
got, ok := ctx.Value(ctxkey.Group).(*Group)
require.True(t, ok)
require.Same(t, hydratedGroup, got)
}
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
Hydrated: true,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{fallbackID: fallbackGroup},
}
svc := &GatewayService{
accountRepo: repo,
groupRepo: groupRepo,
cfg: testConfig(),
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, account)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
ctx := context.Background()
groupID := int64(10)
fallbackID := int64(11)
group := &Group{
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &fallbackID,
}
fallbackGroup := &Group{
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
ClaudeCodeOnly: true,
FallbackGroupID: &groupID,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{
groupID: group,
fallbackID: fallbackGroup,
},
}
svc := &GatewayService{
groupRepo: groupRepo,
}
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
require.Error(t, err)
require.Nil(t, gotGroup)
require.Nil(t, gotID)
require.Contains(t, err.Error(), "fallback group cycle")
}

View File

@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
return nil, err
}
groupID = resolvedGroupID
ctx = s.withGroupContext(ctx, group)
platform = group.Platform
// 检查 Claude Code 客户端限制
if group.ClaudeCodeOnly {
isClaudeCode := IsClaudeCodeClient(ctx)
if !isClaudeCode {
// 非 Claude Code 客户端,检查是否有降级分组
if group.FallbackGroupID != nil {
// 使用降级分组重新调度
fallbackGroupID := *group.FallbackGroupID
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
}
// 无降级分组,拒绝访问
return nil, ErrClaudeCodeOnly
}
}
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
return nil, err
}
ctx = s.withGroupContext(ctx, group)
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
if err != nil {
return nil, err
}
@@ -655,51 +642,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
if !IsGroupContextValid(group) {
return ctx
}
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
return ctx
}
return context.WithValue(ctx, ctxkey.Group, group)
}
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
return group
}
return nil
}
func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
if group := s.groupFromContext(ctx, groupID); group != nil {
return group, nil
}
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
return group, nil
}
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, nil, nil
}
currentID := *groupID
visited := map[int64]struct{}{}
for {
if _, seen := visited[currentID]; seen {
return nil, nil, fmt.Errorf("fallback group cycle detected")
}
visited[currentID] = struct{}{}
group, err := s.resolveGroupByID(ctx, currentID)
if err != nil {
return nil, nil, err
}
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
return group, &currentID, nil
}
if group.FallbackGroupID == nil {
return nil, nil, ErrClaudeCodeOnly
}
currentID = *group.FallbackGroupID
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return groupID, nil
return nil, groupID, nil
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
return groupID, nil
return nil, groupID, nil
}
group, err := s.groupRepo.GetByID(ctx, *groupID)
group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
return nil, nil, err
}
if !group.ClaudeCodeOnly {
return groupID, nil
}
// 分组启用了 Claude Code 限制
if IsClaudeCodeClient(ctx) {
return groupID, nil
}
// 非 Claude Code 客户端,检查降级分组
if group.FallbackGroupID != nil {
return group.FallbackGroupID, nil
}
return nil, ErrClaudeCodeOnly
return group, resolvedID, nil
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if group != nil {
return group.Platform, false, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
group, err := s.resolveGroupByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
return "", false, err
}
return group.Platform, false, nil
}

View File

@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
var group *Group
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
group = ctxGroup
} else {
var err error
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
}
platform = group.Platform
} else {

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
@@ -152,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
type mockGroupRepoForGemini struct {
groups map[int64]*Group
groups map[int64]*Group
getByIDCalls int
getByIDLiteCalls int
}
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
m.getByIDCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
return nil, errors.New("group not found")
}
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
m.getByIDLiteCalls++
if g, ok := m.groups[id]; ok {
return g, nil
}
@@ -248,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
}
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
ctx := context.Background()
groupID := int64(7)
group := &Group{
ID: groupID,
Platform: PlatformGemini,
Status: StatusActive,
Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
}
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
ctx := context.Background()
groupID := int64(7)
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForGemini{}
groupRepo := &mockGroupRepoForGemini{
groups: map[int64]*Group{
groupID: {ID: groupID, Platform: PlatformGemini},
},
}
svc := &GeminiMessagesCompatService{
accountRepo: repo,
groupRepo: groupRepo,
cache: cache,
}
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, 0, groupRepo.getByIDCalls)
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
}
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
ctx := context.Background()

View File

@@ -10,6 +10,7 @@ type Group struct {
RateMultiplier float64
IsExclusive bool
Status string
Hydrated bool // indicates the group was loaded from a trusted repository source
SubscriptionType string
DailyLimitUSD *float64
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
return g.ImagePrice2K
}
}
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
func IsGroupContextValid(group *Group) bool {
if group == nil {
return false
}
if group.ID <= 0 {
return false
}
if !group.Hydrated {
return false
}
if group.Platform == "" || group.Status == "" {
return false
}
return true
}

View File

@@ -16,6 +16,7 @@ var (
type GroupRepository interface {
Create(ctx context.Context, group *Group) error
GetByID(ctx context.Context, id int64) (*Group, error)
GetByIDLite(ctx context.Context, id int64) (*Group, error)
Update(ctx context.Context, group *Group) error
Delete(ctx context.Context, id int64) error
DeleteCascade(ctx context.Context, id int64) ([]int64, error)