merge: 合并主分支改动并保留 ops 监控实现

合并 main 分支的最新改动到 ops 监控分支。
冲突解决策略:保留当前分支的 ops 相关改动,接受主分支的其他改动。

保留的 ops 改动:
- 运维监控配置和依赖注入
- 运维监控 API 处理器和中间件
- 运维监控服务层和数据访问层
- 运维监控前端界面和状态管理

接受的主分支改动:
- Linux DO OAuth 集成
- 账号过期功能
- IP 地址限制功能
- 用量统计优化
- 其他 bug 修复和功能改进
This commit is contained in:
IanShaw027
2026-01-10 13:24:40 +08:00
155 changed files with 9227 additions and 1355 deletions

View File

@@ -9,21 +9,23 @@ import (
)
type Account struct {
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
Status string
ErrorMessage string
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
Status string
ErrorMessage string
LastUsedAt *time.Time
ExpiresAt *time.Time
AutoPauseOnExpired bool
CreatedAt time.Time
UpdatedAt time.Time
Schedulable bool
@@ -60,6 +62,9 @@ func (a *Account) IsSchedulable() bool {
return false
}
now := time.Now()
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
return false
}
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
return false
}

View File

@@ -0,0 +1,71 @@
package service
import (
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type AccountExpiryService struct {
accountRepo AccountRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
return &AccountExpiryService{
accountRepo: accountRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *AccountExpiryService) Start() {
if s == nil || s.accountRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *AccountExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *AccountExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
if err != nil {
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
return
}
if updated > 0 {
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
}
}

View File

@@ -38,6 +38,7 @@ type AccountRepository interface {
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error)
@@ -48,10 +49,12 @@ type AccountRepository interface {
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
@@ -65,35 +68,40 @@ type AccountBulkUpdate struct {
Concurrency *int
Priority *int
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
}
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// AccountService 账号管理服务
@@ -134,6 +142,12 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: StatusActive,
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil {
@@ -224,6 +238,12 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount
if req.Status != nil {
account.Status = *req.Status
}
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {

View File

@@ -103,6 +103,10 @@ func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedula
panic("unexpected SetSchedulable call")
}
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
panic("unexpected AutoPauseExpiredAccounts call")
}
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
panic("unexpected BindGroups call")
}
@@ -135,6 +139,10 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call")
}
@@ -151,6 +159,10 @@ func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call")
}

View File

@@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
if candidate, ok := candidates[0].(map[string]any); ok {
// Check for completion
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
// Extract content
// Extract content first (before checking completion)
if content, ok := candidate["content"].(map[string]any); ok {
if parts, ok := content["parts"].([]any); ok {
for _, part := range parts {
@@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader)
}
}
}
// Check for completion after extracting content
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
}
}

View File

@@ -47,6 +47,7 @@ type UsageLogRepository interface {
// Admin usage listing/stats
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error)
// Account stats
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)

View File

@@ -24,7 +24,7 @@ type AdminService interface {
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error)
GetAllGroups(ctx context.Context) ([]Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
GetGroup(ctx context.Context, id int64) (*Group, error)
@@ -47,6 +47,7 @@ type AdminService interface {
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error)
GetAllProxies(ctx context.Context) ([]Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*Proxy, error)
@@ -99,9 +100,11 @@ type CreateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
}
type UpdateGroupInput struct {
@@ -116,22 +119,26 @@ type UpdateGroupInput struct {
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
// 图片生成计费配置(仅 antigravity 平台使用)
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ImagePrice1K *float64
ImagePrice2K *float64
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
}
type CreateAccountInput struct {
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
Name string
Notes *string
Platform string
Type string
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -148,6 +155,8 @@ type UpdateAccountInput struct {
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
ExpiresAt *int64
AutoPauseOnExpired *bool
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
}
@@ -159,6 +168,7 @@ type BulkUpdateAccountsInput struct {
Concurrency *int
Priority *int
Status string
Schedulable *bool
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
@@ -469,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive)
if err != nil {
return nil, 0, err
}
@@ -511,6 +521,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
// 校验降级分组
if input.FallbackGroupID != nil {
if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil {
return nil, err
}
}
group := &Group{
Name: input.Name,
Description: input.Description,
@@ -525,6 +542,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice1K: imagePrice1K,
ImagePrice2K: imagePrice2K,
ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -548,6 +567,29 @@ func normalizePrice(price *float64) *float64 {
return price
}
// validateFallbackGroup 校验降级分组的有效性
// currentGroupID: 当前分组 ID新建时为 0
// fallbackGroupID: 降级分组 ID
func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error {
// 不能将自己设置为降级分组
if currentGroupID > 0 && currentGroupID == fallbackGroupID {
return fmt.Errorf("cannot set self as fallback group")
}
// 检查降级分组是否存在
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
if err != nil {
return fmt.Errorf("fallback group not found: %w", err)
}
// 降级分组不能启用 claude_code_only否则会造成死循环
if fallbackGroup.ClaudeCodeOnly {
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
}
return nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
@@ -598,6 +640,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
}
// Claude Code 客户端限制
if input.ClaudeCodeOnly != nil {
group.ClaudeCodeOnly = *input.ClaudeCodeOnly
}
if input.FallbackGroupID != nil {
// 校验降级分组
if *input.FallbackGroupID > 0 {
if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil {
return nil, err
}
group.FallbackGroupID = input.FallbackGroupID
} else {
// 传入 0 或负数表示清除降级分组
group.FallbackGroupID = nil
}
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
@@ -700,6 +759,15 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
Status: StatusActive,
Schedulable: true,
}
if input.ExpiresAt != nil && *input.ExpiresAt > 0 {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
}
if input.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -755,6 +823,17 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Status != "" {
account.Status = input.Status
}
if input.ExpiresAt != nil {
if *input.ExpiresAt <= 0 {
account.ExpiresAt = nil
} else {
expiresAt := time.Unix(*input.ExpiresAt, 0)
account.ExpiresAt = &expiresAt
}
}
if input.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *input.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if input.GroupIDs != nil {
@@ -832,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Status != "" {
repoUpdates.Status = &input.Status
}
if input.Schedulable != nil {
repoUpdates.Schedulable = input.Schedulable
}
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
@@ -926,6 +1008,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
return proxies, result.Total, nil
}
func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
}
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}

View File

@@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa
panic("unexpected List call")
}
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
@@ -186,6 +186,10 @@ func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]Proxy
panic("unexpected ListActiveWithAccountCount call")
}
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
panic("unexpected ListWithFiltersAndAccountCount call")
}
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
panic("unexpected ExistsByHostPortAuth call")
}

View File

@@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct {
updated *Group // 记录 Update 调用的参数
getByID *Group // GetByID 返回值
getErr error // GetByID 返回的错误
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersIsExclusive *bool
listWithFiltersGroups []Group
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
@@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP
panic("unexpected List call")
}
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
s.listWithFiltersIsExclusive = isExclusive
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersGroups)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersGroups, result, nil
}
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
@@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, "alpha", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{},
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
require.Equal(t, "", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
isExclusive := true
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "beta", repo.listWithFiltersSearch)
require.NotNil(t, repo.listWithFiltersIsExclusive)
require.True(t, *repo.listWithFiltersIsExclusive)
})
}

View File

@@ -0,0 +1,238 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type accountRepoStubForAdminList struct {
accountRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAccounts)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAccounts, result, nil
}
type proxyRepoStubForAdminList struct {
proxyRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersProtocol string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersProxies []Proxy
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
listWithFiltersAndAccountCountCalls int
listWithFiltersAndAccountCountParams pagination.PaginationParams
listWithFiltersAndAccountCountProtocol string
listWithFiltersAndAccountCountStatus string
listWithFiltersAndAccountCountSearch string
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
listWithFiltersAndAccountCountResult *pagination.PaginationResult
listWithFiltersAndAccountCountErr error
}
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersProtocol = protocol
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersProxies, result, nil
}
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
s.listWithFiltersAndAccountCountCalls++
s.listWithFiltersAndAccountCountParams = params
s.listWithFiltersAndAccountCountProtocol = protocol
s.listWithFiltersAndAccountCountStatus = status
s.listWithFiltersAndAccountCountSearch = search
if s.listWithFiltersAndAccountCountErr != nil {
return nil, nil, s.listWithFiltersAndAccountCountErr
}
result := s.listWithFiltersAndAccountCountResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAndAccountCountProxies, result, nil
}
type redeemRepoStubForAdminList struct {
redeemRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersCodes []RedeemCode
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersType = codeType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersCodes)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersCodes, result, nil
}
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "acc", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
})
}
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &redeemRepoStubForAdminList{
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
package service
import (
"strings"
"time"
)
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
// AntigravityQuotaScope 表示 Antigravity 的配额域
type AntigravityQuotaScope string
const (
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
)
// resolveAntigravityQuotaScope 根据模型名称解析配额域
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
model := normalizeAntigravityModelName(requestedModel)
if model == "" {
return "", false
}
switch {
case strings.HasPrefix(model, "claude-"):
return AntigravityQuotaScopeClaude, true
case strings.HasPrefix(model, "gemini-"):
if isImageGenerationModel(model) {
return AntigravityQuotaScopeGeminiImage, true
}
return AntigravityQuotaScopeGeminiText, true
default:
return "", false
}
}
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.Platform != PlatformAntigravity {
return true
}
scope, ok := resolveAntigravityQuotaScope(requestedModel)
if !ok {
return true
}
resetAt := a.antigravityQuotaScopeResetAt(scope)
if resetAt == nil {
return true
}
now := time.Now()
return !now.Before(*resetAt)
}
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
if a == nil || a.Extra == nil || scope == "" {
return nil
}
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
if !ok {
return nil
}
rawScope, ok := rawScopes[string(scope)].(map[string]any)
if !ok {
return nil
}
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
if !ok || strings.TrimSpace(resetAtRaw) == "" {
return nil
}
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
if err != nil {
return nil
}
return &resetAt
}

View File

@@ -3,16 +3,18 @@ package service
import "time"
type APIKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
IPWhitelist []string
IPBlacklist []string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *APIKey) IsActive() bool {

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
@@ -20,6 +21,7 @@ var (
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
)
const (
@@ -57,16 +59,20 @@ type APIKeyCache interface {
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
}
// APIKeyService API Key服务
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("get user: %w", err)
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 创建API Key记录
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, ErrInsufficientPerms
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}

View File

@@ -2,9 +2,13 @@ package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"log"
"net/mail"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -18,6 +22,7 @@ var (
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
@@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
// 检查是否开放注册默认关闭settingService 未配置时不允许注册)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
if isReservedEmail(email) {
return "", nil, ErrEmailReserved
}
// 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
// 这是一个配置错误,不应该允许绕过验证
if s.emailService == nil {
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
return "", nil, ErrServiceUnavailable
}
if verifyCode == "" {
return "", nil, ErrEmailVerifyRequired
}
// 验证邮箱验证码
if s.emailService != nil {
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
return "", nil, fmt.Errorf("verify code: %w", err)
}
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
return "", nil, fmt.Errorf("verify code: %w", err)
}
}
@@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
}
if err := s.userRepo.Create(ctx, user); err != nil {
// 优先检查邮箱冲突错误(竞态条件下可能发生)
if errors.Is(err, ErrEmailExists) {
return "", nil, ErrEmailExists
}
log.Printf("[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
@@ -148,11 +166,15 @@ type SendVerifyCodeResult struct {
// SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
// 检查是否开放注册(默认关闭)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled
}
if isReservedEmail(email) {
return ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
@@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
// 检查是否开放注册(默认关闭)
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled
}
if isReservedEmail(email) {
return nil, ErrEmailReserved
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
@@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
// IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil {
return true
return false // 安全默认settingService 未配置时关闭注册
}
return s.settingService.IsRegistrationEnabled(ctx)
}
@@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
return token, user, nil
}
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
// - 如果邮箱已存在:直接登录(不需要本地密码)
// - 如果邮箱不存在:创建新用户并登录
//
// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth例如 OpenAI/Gemini
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册。
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return "", nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 新用户默认值。
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
// 并发场景GetByEmail 与 Create 之间用户被创建。
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return "", nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return "", nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。
@@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
// token 过期但仍返回 claims用于 RefreshToken 等场景)
// jwt-go 在解析时即使遇到过期错误token.Claims 仍会被填充
if claims, ok := token.Claims.(*JWTClaims); ok {
return claims, ErrTokenExpired
}
return nil, ErrTokenExpired
}
return nil, ErrInvalidToken
@@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
return nil, ErrInvalidToken
}
func randomHexString(byteLength int) (string, error) {
if byteLength <= 0 {
byteLength = 16
}
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now()

View File

@@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) {
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil设置项不存在注册应该默认关闭
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nilemailService 未配置)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
repo := &userRepoStub{}
cache := &emailCacheStub{} // 配置 emailService
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
@@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
@@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) {
func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_ReservedEmail(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
require.ErrorIs(t, err, ErrEmailReserved)
}
func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
// 模拟竞态条件ExistsByEmail 返回 false但 Create 时因唯一约束失败
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, nil, nil)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err)
@@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) {
require.Len(t, repo.created, 1)
require.True(t, user.CheckPassword("password"))
}
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
// 创建用户并生成 token
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
token, err := service.GenerateToken(user)
require.NoError(t, err)
// 验证有效 token
claims, err := service.ValidateToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.Equal(t, int64(1), claims.UserID)
// 模拟过期 token通过创建一个过期很久的 token
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1 // 恢复
// 验证过期 token 应返回 claims 和 ErrTokenExpired
claims, err = service.ValidateToken(expiredToken)
require.ErrorIs(t, err, ErrTokenExpired)
require.NotNil(t, claims, "claims should not be nil when token is expired")
require.Equal(t, int64(1), claims.UserID)
require.Equal(t, "test@test.com", claims.Email)
}
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
repo := &userRepoStub{user: user}
service := newAuthService(repo, nil, nil)
// 创建过期 token
service.cfg.JWT.ExpireHour = -1
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1
// RefreshToken 使用过期 token 不应 panic
require.NotPanics(t, func() {
newToken, err := service.RefreshToken(context.Background(), expiredToken)
require.NoError(t, err)
require.NotEmpty(t, newToken)
})
}

View File

@@ -0,0 +1,265 @@
package service
import (
"context"
"net/http"
"regexp"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
// 完全学习自 claude-relay-service 项目的验证逻辑
type ClaudeCodeValidator struct{}
var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI大小写不敏感)
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
// Claude Code 官方 System Prompt 模板
// 从 claude-relay-service/src/utils/contents.js 提取
var claudeCodeSystemPrompts = []string{
// claudeOtherSystemPrompt1 - Primary
"You are Claude Code, Anthropic's official CLI for Claude.",
// claudeOtherSystemPrompt3 - Agent SDK
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
// claudeOtherSystemPrompt4 - Compact Agent SDK
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
// exploreAgentSystemPrompt
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
"You are a helpful AI assistant tasked with summarizing conversations.",
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
"You are an interactive CLI tool that helps users",
}
// NewClaudeCodeValidator 创建验证器实例
func NewClaudeCodeValidator() *ClaudeCodeValidator {
return &ClaudeCodeValidator{}
}
// Validate 验证请求是否来自 Claude Code CLI
// 采用与 claude-relay-service 完全一致的验证策略:
//
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
// Step 3: 对于 messages 路径,进行严格验证:
// - System prompt 相似度检查
// - X-App header 检查
// - anthropic-beta header 检查
// - anthropic-version header 检查
// - metadata.user_id 格式验证
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
// Step 1: User-Agent 检查
ua := r.Header.Get("User-Agent")
if !claudeCodeUAPattern.MatchString(ua) {
return false
}
// Step 2: 非 messages 路径,只要 UA 匹配就通过
path := r.URL.Path
if !strings.Contains(path, "messages") {
return true
}
// Step 3: messages 路径,进行严格验证
// 3.1 检查 system prompt 相似度
if !v.hasClaudeCodeSystemPrompt(body) {
return false
}
// 3.2 检查必需的 headers值不为空即可
xApp := r.Header.Get("X-App")
if xApp == "" {
return false
}
anthropicBeta := r.Header.Get("anthropic-beta")
if anthropicBeta == "" {
return false
}
anthropicVersion := r.Header.Get("anthropic-version")
if anthropicVersion == "" {
return false
}
// 3.3 验证 metadata.user_id
if body == nil {
return false
}
metadata, ok := body["metadata"].(map[string]any)
if !ok {
return false
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return false
}
if !userIDPattern.MatchString(userID) {
return false
}
return true
}
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
// 使用字符串相似度匹配Dice coefficient
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
if body == nil {
return false
}
// 检查 model 字段
if _, ok := body["model"].(string); !ok {
return false
}
// 获取 system 字段
systemEntries, ok := body["system"].([]any)
if !ok {
return false
}
// 检查每个 system entry
for _, entry := range systemEntries {
entryMap, ok := entry.(map[string]any)
if !ok {
continue
}
text, ok := entryMap["text"].(string)
if !ok || text == "" {
continue
}
// 计算与所有模板的最佳相似度
bestScore := v.bestSimilarityScore(text)
if bestScore >= systemPromptThreshold {
return true
}
}
return false
}
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
normalizedText := normalizePrompt(text)
bestScore := 0.0
for _, template := range claudeCodeSystemPrompts {
normalizedTemplate := normalizePrompt(template)
score := diceCoefficient(normalizedText, normalizedTemplate)
if score > bestScore {
bestScore = score
}
}
return bestScore
}
// normalizePrompt 标准化提示词文本(去除多余空白)
func normalizePrompt(text string) string {
// 将所有空白字符替换为单个空格,并去除首尾空白
return strings.Join(strings.Fields(text), " ")
}
// diceCoefficient 计算两个字符串的 Dice 系数SørensenDice coefficient
// 这是 string-similarity 库使用的算法
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
func diceCoefficient(a, b string) float64 {
if a == b {
return 1.0
}
if len(a) < 2 || len(b) < 2 {
return 0.0
}
// 生成 bigrams
bigramsA := getBigrams(a)
bigramsB := getBigrams(b)
if len(bigramsA) == 0 || len(bigramsB) == 0 {
return 0.0
}
// 计算交集大小
intersection := 0
for bigram, countA := range bigramsA {
if countB, exists := bigramsB[bigram]; exists {
if countA < countB {
intersection += countA
} else {
intersection += countB
}
}
}
// 计算总 bigram 数量
totalA := 0
for _, count := range bigramsA {
totalA += count
}
totalB := 0
for _, count := range bigramsB {
totalB += count
}
return float64(2*intersection) / float64(totalA+totalB)
}
// getBigrams 获取字符串的所有 bigrams相邻字符对
func getBigrams(s string) map[string]int {
bigrams := make(map[string]int)
runes := []rune(strings.ToLower(s))
for i := 0; i < len(runes)-1; i++ {
bigram := string(runes[i : i+2])
bigrams[bigram]++
}
return bigrams
}
// ValidateUserAgent 仅验证 User-Agent用于不需要解析请求体的场景
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
return claudeCodeUAPattern.MatchString(ua)
}
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
// 只要存在匹配的系统提示词就返回 true用于宽松检测
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
return v.hasClaudeCodeSystemPrompt(body)
}
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
func IsClaudeCodeClient(ctx context.Context) bool {
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
return v
}
return false
}
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/tls"
"fmt"
"log"
"math/big"
"net/smtp"
"strconv"
@@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配
if data.Code != code {
data.Attempts++
_ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL)
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err)
}
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
@@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
}
// 验证成功,删除验证码
_ = s.cache.DeleteVerificationCode(ctx, email)
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err)
}
return nil
}

View File

@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, err
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForPlatform) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (m *mockAccountRepoForPlatform) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
@@ -133,6 +136,9 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co
func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -145,6 +151,9 @@ func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context,
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -163,14 +172,14 @@ type mockGatewayCacheForPlatform struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
@@ -178,7 +187,7 @@ func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, s
return nil
}
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}

View File

@@ -33,8 +33,9 @@ const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 10 * 1024 * 1024
defaultMaxLineSize = 40 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
// sseDataRe matches SSE data lines with optional whitespace after colon.
@@ -43,8 +44,21 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
claudeCodePromptPrefixes = []string{
"You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...
"You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体
"You are a file search specialist for Claude Code", // Explore Agent 版
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
}
)
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
@@ -69,9 +83,17 @@ var allowedHeaders = map[string]bool{
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
func derefGroupID(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
type AccountWaitPlan struct {
@@ -98,12 +120,13 @@ type ClaudeUsage struct {
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
ClientDisconnect bool // 客户端是否在流式传输过程中断开
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量
@@ -213,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
@@ -344,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
// 检查 Claude Code 客户端限制
if group.ClaudeCodeOnly {
isClaudeCode := IsClaudeCodeClient(ctx)
if !isClaudeCode {
// 非 Claude Code 客户端,检查是否有降级分组
if group.FallbackGroupID != nil {
// 使用降级分组重新调度
fallbackGroupID := *group.FallbackGroupID
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
}
// 无降级分组,拒绝访问
return nil, ErrClaudeCodeOnly
}
}
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
@@ -355,17 +393,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if hasForcePlatform && groupID != nil {
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err == nil {
return account, nil
}
// 分组中找不到,回退查询全部该平台账户
groupID = nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
@@ -374,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
return nil, err
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
@@ -440,15 +476,16 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
if err == nil && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -482,6 +519,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
@@ -502,7 +542,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
@@ -552,7 +592,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
@@ -580,7 +620,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
@@ -588,7 +628,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
@@ -615,6 +655,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
if groupID == nil {
return groupID, nil
}
// 强制平台模式不检查 Claude Code 限制
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
return groupID, nil
}
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
if !group.ClaudeCodeOnly {
return groupID, nil
}
// 分组启用了 Claude Code 限制
if IsClaudeCodeClient(ctx) {
return groupID, nil
}
// 非 Claude Code 客户端,检查降级分组
if group.FallbackGroupID != nil {
return group.FallbackGroupID, nil
}
return nil, ErrClaudeCodeOnly
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
@@ -660,9 +736,7 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
@@ -685,6 +759,23 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform
return account.Platform == platform
}
// isAccountInGroup checks if the account belongs to the specified group.
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
if groupID == nil {
return true // 无分组限制
}
if account == nil {
return false
}
for _, ag := range account.AccountGroups {
if ag.GroupID == *groupID {
return true
}
}
return false
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
@@ -719,13 +810,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
@@ -756,6 +847,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
@@ -792,7 +886,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
@@ -808,14 +902,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
// 检查账号分组归属和有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
@@ -848,6 +942,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
@@ -884,7 +981,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
@@ -1013,15 +1110,15 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 支持 string 和 []any 两种格式
// 使用前缀匹配支持多种变体标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool {
switch v := system.(type) {
case string:
return v == claudeCodeSystemPrompt
return hasClaudeCodePrefix(v)
case []any:
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) {
return true
}
}
@@ -1030,6 +1127,16 @@ func systemIncludesClaudeCodePrompt(system any) bool {
return false
}
// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头
func hasClaudeCodePrefix(text string) bool {
for _, prefix := range claudeCodePromptPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
@@ -1073,6 +1180,124 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
return result
}
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
// 超限时优先从 messages 中移除 cache_control保护 system 中的缓存控制
func enforceCacheControlLimit(body []byte) []byte {
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
return body
}
// 计算当前 cache_control 块数量
count := countCacheControlBlocks(data)
if count <= maxCacheControlBlocks {
return body
}
// 超限:优先从 messages 中移除,再从 system 中移除
for count > maxCacheControlBlocks {
if removeCacheControlFromMessages(data) {
count--
continue
}
if removeCacheControlFromSystem(data) {
count--
continue
}
break
}
result, err := json.Marshal(data)
if err != nil {
return body
}
return result
}
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
func countCacheControlBlocks(data map[string]any) int {
count := 0
// 统计 system 中的块
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
if _, has := m["cache_control"]; has {
count++
}
}
}
}
// 统计 messages 中的块
if messages, ok := data["messages"].([]any); ok {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for _, item := range content {
if m, ok := item.(map[string]any); ok {
if _, has := m["cache_control"]; has {
count++
}
}
}
}
}
}
}
return count
}
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control从头开始
// 返回 true 表示成功移除false 表示没有可移除的
func removeCacheControlFromMessages(data map[string]any) bool {
messages, ok := data["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, item := range content {
if m, ok := item.(map[string]any); ok {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
}
return false
}
// removeCacheControlFromSystem 从 system 中移除一个 cache_control从尾部开始保护注入的 prompt
// 返回 true 表示成功移除false 表示没有可移除的
func removeCacheControlFromSystem(data map[string]any) bool {
system, ok := data["system"].([]any)
if !ok {
return false
}
// 从尾部开始移除,保护开头注入的 Claude Code prompt
for i := len(system) - 1; i >= 0; i-- {
if m, ok := system[i].(map[string]any); ok {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
return false
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
@@ -1093,6 +1318,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body = injectClaudeCodePrompt(body, parsed.System)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeAPIKey {
@@ -1316,6 +1544,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
@@ -1328,6 +1557,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
@@ -1336,12 +1566,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
}
@@ -1696,8 +1927,9 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
// streamingResult 流式响应结果
type streamingResult struct {
usage *ClaudeUsage
firstTokenMs *int
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
@@ -1793,14 +2025,27 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for {
select {
case ev, ok := <-events:
if !ok {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
// 上游完成,返回结果
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
@@ -1811,38 +2056,40 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
line := ev.line
if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
}
// 转发行
// 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
flusher.Flush()
}
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
// 无论客户端是否断开,都解析 usage仅对 data 行)
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
sendErrorEvent("write_failed")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
case <-intervalCh:
@@ -1850,6 +2097,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
// 客户端已断开,上游也超时了,返回已收集的 usage
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
@@ -2003,6 +2255,8 @@ type RecordUsageInput struct {
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -2088,6 +2342,16 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID

View File

@@ -109,12 +109,12 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
cacheKey := "gemini:" + sessionHash
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err == nil && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
valid := false
if account.Platform == platform {
valid = true
@@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL)
return account, nil
}
}
@@ -172,6 +172,9 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if useMixedScheduling && acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
@@ -217,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
}
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL)
}
return selected, nil

View File

@@ -90,6 +90,9 @@ func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, error
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
return nil
}
func (m *mockAccountRepoForGemini) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
return 0, nil
}
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
return nil
}
@@ -118,6 +121,9 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -128,6 +134,9 @@ func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, i
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -163,7 +172,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
@@ -187,14 +196,14 @@ type mockGatewayCacheForGemini struct {
sessionBindings map[string]int64
}
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := m.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if m.sessionBindings == nil {
m.sessionBindings = make(map[string]int64)
}
@@ -202,7 +211,7 @@ func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, ses
return nil
}
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}

View File

@@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
// - ai_studio: requires a user-provided OAuth client.
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
@@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
// Google One accounts use cloudaicompanion API, which requires a project_id.
// For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API.
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err)
return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err)
}
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID)
}
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier
var storageInfo *geminicli.DriveStorageInfo

View File

@@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
wantProjectID: "",
},
{
name: "google_one uses custom client when configured and redirects to localhost",
name: "google_one always forces built-in client even when custom client configured",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
@@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
},
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes,
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "",
},
{

View File

@@ -22,6 +22,10 @@ type Group struct {
ImagePrice2K *float64
ImagePrice4K *float64
// Claude Code 客户端限制
ClaudeCodeOnly bool
FallbackGroupID *int64
CreatedAt time.Time
UpdatedAt time.Time

View File

@@ -21,7 +21,7 @@ type GroupRepository interface {
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)

View File

@@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 {
return nil
}
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL)
}
// SelectAccount selects an OpenAI account with sticky session support
@@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 1. Check sticky session
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
// Refresh sticky session TTL
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return account, nil
}
}
@@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
// 4. Set sticky session
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL)
}
return selected, nil
@@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil {
stickyAccountID = accountID
}
}
@@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
// ============ Layer 1: Sticky session ============
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
@@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
@@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
@@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
bodyModified = true
}
// For OAuth accounts using ChatGPT internal API, add store: false
// For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
bodyModified = true
// Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"content": "..."}
if normalizeInputForCodexAPI(reqBody) {
bodyModified = true
}
}
// Re-serialize body only if modified
@@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return newBody
}
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
// that the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
// {"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
// {"content": "hello"}
//
// This function modifies reqBody in-place and returns true if any modification was made.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
input, ok := reqBody["input"]
if !ok {
return false
}
// Handle case where input is a simple string (already compatible)
if _, isString := input.(string); isString {
return false
}
// Handle case where input is an array of messages
inputArray, ok := input.([]any)
if !ok {
return false
}
modified := false
for _, item := range inputArray {
message, ok := item.(map[string]any)
if !ok {
continue
}
content, ok := message["content"]
if !ok {
continue
}
// If content is already a string, no conversion needed
if _, isString := content.(string); isString {
continue
}
// If content is an array (AI SDK format), convert to string
contentArray, ok := content.([]any)
if !ok {
continue
}
// Extract text from content array
var textParts []string
for _, part := range contentArray {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
// Handle different content types
partType, _ := partMap["type"].(string)
switch partType {
case "input_text", "text":
// Extract text from input_text or text type
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
case "input_image", "image":
// For images, we need to preserve the original format
// as ChatGPT Codex API may support images in a different way
// For now, skip image parts (they will be lost in conversion)
// TODO: Consider preserving image data or handling it separately
continue
case "input_file", "file":
// Similar to images, file inputs may need special handling
continue
default:
// For unknown types, try to extract text if available
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
}
}
// Convert content array to string
if len(textParts) > 0 {
message["content"] = strings.Join(textParts, "\n")
modified = true
}
}
return modified
}
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
@@ -1092,6 +1196,8 @@ type OpenAIRecordUsageInput struct {
User *User
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage records usage and deducts balance
@@ -1161,6 +1267,16 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}

View File

@@ -20,6 +20,7 @@ type ProxyRepository interface {
List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)

View File

@@ -345,7 +345,7 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
@@ -353,7 +353,10 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.accountRepo.ClearRateLimit(ctx, accountID)
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
return err
}
return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
}
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {

View File

@@ -18,6 +18,13 @@ type SystemSettings struct {
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
// LinuxDo Connect OAuth 登录(终端用户 SSO
LinuxDoConnectEnabled bool
LinuxDoConnectClientID string
LinuxDoConnectClientSecret string
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
SiteName string
SiteLogo string
SiteSubtitle string
@@ -57,5 +64,6 @@ type PublicSettings struct {
APIBaseURL string
ContactInfo string
DocURL string
LinuxDoOAuthEnabled bool
Version string
}

View File

@@ -38,6 +38,8 @@ type UsageLog struct {
Stream bool
DurationMs *int
FirstTokenMs *int
UserAgent *string
IPAddress *string
// 图片生成字段
ImageCount int

View File

@@ -319,3 +319,12 @@ func (s *UsageService) GetGlobalStats(ctx context.Context, startTime, endTime ti
}
return stats, nil
}
// GetStatsWithFilters returns usage stats with optional filters.
func (s *UsageService) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
stats, err := s.usageRepo.GetStatsWithFilters(ctx, filters)
if err != nil {
return nil, fmt.Errorf("get usage stats with filters: %w", err)
}
return stats, nil
}

View File

@@ -49,6 +49,13 @@ func ProvideTokenRefreshService(
return svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute)
svc.Start()
return svc
}
// ProvideTimingWheelService creates and starts TimingWheelService
func ProvideTimingWheelService() *TimingWheelService {
svc := NewTimingWheelService()
@@ -168,6 +175,7 @@ var ProviderSet = wire.NewSet(
NewCRSSyncService,
ProvideUpdateService,
ProvideTokenRefreshService,
ProvideAccountExpiryService,
ProvideTimingWheelService,
ProvideDeferredService,
NewAntigravityQuotaFetcher,