merge: 合并主分支改动并保留 ops 监控实现

合并 main 分支的最新改动到 ops 监控分支。
冲突解决策略:保留当前分支的 ops 相关改动,接受主分支的其他改动。

保留的 ops 改动:
- 运维监控配置和依赖注入
- 运维监控 API 处理器和中间件
- 运维监控服务层和数据访问层
- 运维监控前端界面和状态管理

接受的主分支改动:
- Linux DO OAuth 集成
- 账号过期功能
- IP 地址限制功能
- 用量统计优化
- 其他 bug 修复和功能改进
This commit is contained in:
IanShaw027
2026-01-10 13:24:40 +08:00
155 changed files with 9227 additions and 1355 deletions

View File

@@ -76,7 +76,8 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -84,6 +85,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
if account.LastUsedAt != nil {
builder.SetLastUsedAt(*account.LastUsedAt)
}
if account.ExpiresAt != nil {
builder.SetExpiresAt(*account.ExpiresAt)
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
}
@@ -280,7 +284,8 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable)
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
@@ -292,6 +297,11 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
} else {
builder.ClearLastUsedAt()
}
if account.ExpiresAt != nil {
builder.SetExpiresAt(*account.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
if account.RateLimitedAt != nil {
builder.SetRateLimitedAt(*account.RateLimitedAt)
} else {
@@ -570,6 +580,7 @@ func (r *accountRepository) ListSchedulable(ctx context.Context) ([]service.Acco
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -596,6 +607,7 @@ func (r *accountRepository) ListSchedulableByPlatform(ctx context.Context, platf
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -629,6 +641,7 @@ func (r *accountRepository) ListSchedulableByPlatforms(ctx context.Context, plat
dbaccount.StatusEQ(service.StatusActive),
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
).
@@ -662,6 +675,40 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
return err
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
path := "{antigravity_quota_scopes," + string(scope) + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -705,6 +752,27 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
return err
}
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'antigravity_quota_scopes', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
return nil
}
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
builder := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -727,6 +795,27 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
return err
}
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
result, err := r.sql.ExecContext(ctx, `
UPDATE accounts
SET schedulable = FALSE,
updated_at = NOW()
WHERE deleted_at IS NULL
AND schedulable = TRUE
AND auto_pause_on_expired = TRUE
AND expires_at IS NOT NULL
AND expires_at <= $1
`, now)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
}
func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
if len(updates) == 0 {
return nil
@@ -797,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Status)
idx++
}
if updates.Schedulable != nil {
setClauses = append(setClauses, "schedulable = $"+itoa(idx))
args = append(args, *updates.Schedulable)
idx++
}
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 {
payload, err := json.Marshal(updates.Credentials)
@@ -861,6 +955,7 @@ func (r *accountRepository) queryAccountsByGroup(ctx context.Context, groupID in
preds = append(preds,
dbaccount.SchedulableEQ(true),
tempUnschedulablePredicate(),
notExpiredPredicate(now),
dbaccount.Or(dbaccount.OverloadUntilIsNil(), dbaccount.OverloadUntilLTE(now)),
dbaccount.Or(dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(now)),
)
@@ -971,6 +1066,14 @@ func tempUnschedulablePredicate() dbpredicate.Account {
})
}
func notExpiredPredicate(now time.Time) dbpredicate.Account {
return dbaccount.Or(
dbaccount.ExpiresAtIsNil(),
dbaccount.ExpiresAtGT(now),
dbaccount.AutoPauseOnExpiredEQ(false),
)
}
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
out := make(map[int64]tempUnschedSnapshot)
if len(accountIDs) == 0 {
@@ -1086,6 +1189,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,

View File

@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
builder := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
Save(ctx)
SetNillableGroupID(key.GroupID)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
}
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.CreatedAt = created.CreatedAt
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
} else {
builder.ClearIPWhitelist()
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
} else {
builder.ClearIPBlacklist()
}
affected, err := builder.Save(ctx)
if err != nil {
return err
@@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
return nil
}
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
@@ -325,6 +347,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) {
key := stickySessionPrefix + sessionHash
// buildSessionKey 构建 session key包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error {
key := stickySessionPrefix + sessionHash
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() {
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
@@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
@@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
@@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID
groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}

View File

@@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client
// - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
@@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}

View File

@@ -14,23 +14,33 @@ import (
)
type githubReleaseClient struct {
httpClient *http.Client
allowPrivateHosts bool
httpClient *http.Client
downloadHTTPClient *http.Client
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
allowPrivate := false
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: allowPrivate,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
// 下载客户端需要更长的超时时间
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ProxyURL: proxyURL,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
return &githubReleaseClient{
httpClient: sharedClient,
allowPrivateHosts: allowPrivate,
httpClient: sharedClient,
downloadHTTPClient: downloadClient,
}
}
@@ -68,15 +78,8 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
return err
}
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ValidateResolvedIP: true,
AllowPrivateHosts: c.allowPrivateHosts,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
resp, err := downloadClient.Do(req)
// 使用预配置的下载客户端(已包含代理配置)
resp, err := c.downloadHTTPClient.Do(req)
if err != nil {
return err
}

View File

@@ -39,8 +39,8 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
allowPrivateHosts: true,
httpClient: &http.Client{},
downloadHTTPClient: &http.Client{},
}
}
@@ -234,7 +234,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -254,7 +254,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -272,7 +272,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
@@ -288,7 +288,7 @@ func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
allowPrivateHosts: true,
downloadHTTPClient: &http.Client{},
}
ctx, cancel := context.WithCancel(context.Background())

View File

@@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays)
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx)
if err == nil {
@@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
updated, err := r.client.Group.UpdateOneID(groupIn.ID).
builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
@@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
Save(ctx)
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}
@@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
return r.ListWithFilters(ctx, params, "", "", "", nil)
}
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
q := r.client.Group.Query()
if platform != "" {
@@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
if status != "" {
q = q.Where(group.StatusEQ(status))
}
if search != "" {
q = q.Where(group.Or(
group.NameContainsFold(search),
group.DescriptionContainsFold(search),
))
}
if isExclusive != nil {
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}

View File

@@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI,
"",
"",
nil,
)
s.Require().NoError(err, "ListWithFilters base")
@@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() {
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform
@@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() {
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
@@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
}))
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",
@@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
s.Require().NoError(err)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive)
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
@@ -17,17 +16,12 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ValidateResolvedIP: validateResolvedIP,
AllowPrivateHosts: allowPrivate,
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}

View File

@@ -6,7 +6,6 @@ import (
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -20,13 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient(&config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}).(*pricingRemoteClient)
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}

View File

@@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
return outProxies, paginationResultFromTotal(int64(total), params), nil
}
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -109,6 +109,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
stream,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
created_at
@@ -118,8 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24,
$25, $26, $27
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -129,6 +130,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
subscriptionID := nullInt64(log.SubscriptionID)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
var requestIDArg any
@@ -161,6 +164,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.Stream,
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
createdAt,
@@ -1388,6 +1393,81 @@ func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endT
return stats, nil
}
// GetStatsWithFilters gets usage statistics with optional filters
func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters UsageLogFilters) (*UsageStats, error) {
conditions := make([]string, 0, 9)
args := make([]any, 0, 9)
if filters.UserID > 0 {
conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1))
args = append(args, filters.UserID)
}
if filters.APIKeyID > 0 {
conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1))
args = append(args, filters.APIKeyID)
}
if filters.AccountID > 0 {
conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1))
args = append(args, filters.AccountID)
}
if filters.GroupID > 0 {
conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1))
args = append(args, filters.GroupID)
}
if filters.Model != "" {
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
}
if filters.EndTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at <= $%d", len(args)+1))
args = append(args, *filters.EndTime)
}
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
if err := scanSingleRow(
ctx,
r.sql,
query,
args,
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
// AccountUsageHistory represents daily usage history for an account
type AccountUsageHistory = usagestats.AccountUsageHistory
@@ -1795,6 +1875,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
@@ -1826,6 +1908,8 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&stream,
&durationMs,
&firstTokenMs,
&userAgent,
&ipAddress,
&imageCount,
&imageSize,
&createdAt,
@@ -1877,6 +1961,12 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
value := int(firstTokenMs.Int64)
log.FirstTokenMs = &value
}
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if ipAddress.Valid {
log.IPAddress = &ipAddress.String
}
if imageSize.Valid {
log.ImageSize = &imageSize.String
}

View File

@@ -25,6 +25,18 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
return NewConcurrencyCache(rdb, cfg.Gateway.ConcurrencySlotTTLMinutes, waitTTLSeconds)
}
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
}
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -54,8 +66,8 @@ var ProviderSet = wire.NewSet(
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
NewPricingRemoteClient,
NewGitHubReleaseClient,
ProvidePricingRemoteClient,
ProvideGitHubReleaseClient,
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,