merge: 合并main分支最新改动

解决冲突:
- backend/internal/config/config.go: 合并Ops和Dashboard配置
- backend/internal/server/api_contract_test.go: 合并handler初始化
- backend/internal/service/openai_gateway_service.go: 保留Ops错误追踪逻辑
- backend/internal/service/wire.go: 合并Ops和APIKeyAuth provider

主要合并内容:
- Dashboard缓存和预聚合功能
- API Key认证缓存优化
- Codex转换支持
- 使用日志分区表
This commit is contained in:
IanShaw027
2026-01-11 23:15:01 +08:00
58 changed files with 5385 additions and 351 deletions

View File

@@ -186,9 +186,11 @@ type BulkUpdateAccountResult struct {
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
Failed int `json:"failed"`
Results []BulkUpdateAccountResult `json:"results"`
Success int `json:"success"`
Failed int `json:"failed"`
SuccessIDs []int64 `json:"success_ids"`
FailedIDs []int64 `json:"failed_ids"`
Results []BulkUpdateAccountResult `json:"results"`
}
type CreateProxyInput struct {
@@ -244,14 +246,15 @@ type ProxyExitInfoProber interface {
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
userRepo UserRepository
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewAdminService creates a new AdminService
@@ -264,16 +267,18 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -323,6 +328,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
}
oldConcurrency := user.Concurrency
oldStatus := user.Status
oldRole := user.Role
if input.Email != "" {
user.Email = input.Email
@@ -355,6 +362,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
@@ -393,6 +405,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
log.Printf("delete user failed: user_id=%d err=%v", id, err)
return err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
}
return nil
}
@@ -420,6 +435,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
balanceDiff := user.Balance - oldBalance
if s.authCacheInvalidator != nil && balanceDiff != 0 {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService != nil {
go func() {
@@ -431,7 +450,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}()
}
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
code, err := GenerateRedeemCode()
if err != nil {
@@ -675,10 +693,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
var groupKeys []string
if s.authCacheInvalidator != nil {
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
if err == nil {
groupKeys = keys
}
}
affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
if err != nil {
return err
@@ -697,6 +726,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
}()
}
if s.authCacheInvalidator != nil {
for _, key := range groupKeys {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
}
}
return nil
}
@@ -885,7 +919,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
result := &BulkUpdateAccountsResult{
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
FailedIDs: make([]int64, 0, len(input.AccountIDs)),
Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
}
if len(input.AccountIDs) == 0 {
@@ -949,6 +985,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
@@ -958,6 +995,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
@@ -967,6 +1005,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.FailedIDs = append(result.FailedIDs, accountID)
result.Results = append(result.Results, entry)
continue
}
@@ -974,6 +1013,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry.Success = true
result.Success++
result.SuccessIDs = append(result.SuccessIDs, accountID)
result.Results = append(result.Results, entry)
}

View File

@@ -0,0 +1,80 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
type accountRepoStubForBulkUpdate struct {
accountRepoStub
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
if s.bulkUpdateErr != nil {
return 0, s.bulkUpdateErr
}
return int64(len(ids)), nil
}
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
if err, ok := s.bindGroupErrByID[accountID]; ok {
return err
}
return nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
svc := &adminServiceImpl{accountRepo: repo}
schedulable := true
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2, 3},
Schedulable: &schedulable,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 3, result.Success)
require.Equal(t, 0, result.Failed)
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
require.Empty(t, result.FailedIDs)
require.Len(t, result.Results, 3)
}
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
bindGroupErrByID: map[int64]error{
2: errors.New("bind failed"),
},
}
svc := &adminServiceImpl{accountRepo: repo}
groupIDs := []int64{10}
schedulable := false
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2, 3},
GroupIDs: &groupIDs,
Schedulable: &schedulable,
SkipMixedChannelCheck: true,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 2, result.Success)
require.Equal(t, 1, result.Failed)
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type balanceUserRepoStub struct {
*userRepoStub
updateErr error
updated []*User
}
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
if s.updateErr != nil {
return s.updateErr
}
if user == nil {
return nil
}
clone := *user
s.updated = append(s.updated, &clone)
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
type balanceRedeemRepoStub struct {
*redeemRepoStub
created []*RedeemCode
}
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
clone := *code
s.created = append(s.created, &clone)
return nil
}
type authCacheInvalidatorStub struct {
userIDs []int64
groupIDs []int64
keys []string
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
s.keys = append(s.keys, key)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
s.userIDs = append(s.userIDs, userID)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
s.groupIDs = append(s.groupIDs, groupID)
}
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
require.NoError(t, err)
require.Equal(t, []int64{7}, invalidator.userIDs)
require.Len(t, redeemRepo.created, 1)
}
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
require.NoError(t, err)
require.Empty(t, invalidator.userIDs)
require.Empty(t, redeemRepo.created)
}

View File

@@ -0,0 +1,46 @@
package service
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`
User APIKeyAuthUserSnapshot `json:"user"`
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
}
// APIKeyAuthUserSnapshot 用户快照
type APIKeyAuthUserSnapshot struct {
ID int64 `json:"id"`
Status string `json:"status"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
}
// APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
type APIKeyAuthCacheEntry struct {
NotFound bool `json:"not_found"`
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
}

View File

@@ -0,0 +1,269 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
l2TTL time.Duration
negativeTTL time.Duration
jitterPercent int
singleflight bool
}
var (
jitterRandMu sync.Mutex
// 认证缓存抖动使用独立随机源,避免全局 Seed
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
}
auth := cfg.APIKeyAuth
return apiKeyAuthCacheConfig{
l1Size: auth.L1Size,
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
jitterPercent: auth.JitterPercent,
singleflight: auth.Singleflight,
}
}
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
return c.l1Size > 0 && c.l1TTL > 0
}
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
return c.l2TTL > 0
}
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
}
if c.jitterPercent <= 0 {
return ttl
}
percent := c.jitterPercent
if percent > 100 {
percent = 100
}
delta := float64(percent) / 100
jitterRandMu.Lock()
randVal := jitterRand.Float64()
jitterRandMu.Unlock()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
if !s.authCfg.l1Enabled() {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(s.authCfg.l1Size) * 10,
MaxCost: int64(s.authCfg.l1Size),
BufferItems: 64,
})
if err != nil {
return
}
s.authCacheL1 = cache
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
}
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
if s.authCacheL1 != nil {
if val, ok := s.authCacheL1.Get(cacheKey); ok {
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
return entry, true
}
}
}
if s.cache == nil || !s.authCfg.l2Enabled() {
return nil, false
}
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
if err != nil {
return nil, false
}
s.setAuthCacheL1(cacheKey, entry)
return entry, true
}
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
if s.authCacheL1 == nil || entry == nil {
return
}
ttl := s.authCfg.l1TTL
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
ttl = s.authCfg.negativeTTL
}
ttl = s.authCfg.jitterTTL(ttl)
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
}
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
if entry == nil {
return
}
s.setAuthCacheL1(cacheKey, entry)
if s.cache == nil || !s.authCfg.l2Enabled() {
return
}
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
}
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
if s.authCacheL1 != nil {
s.authCacheL1.Del(cacheKey)
}
if s.cache == nil {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
if errors.Is(err, ErrAPIKeyNotFound) {
entry := &APIKeyAuthCacheEntry{NotFound: true}
if s.authCfg.negativeEnabled() {
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
}
return entry, nil
}
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
return entry, nil
}
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
if entry == nil {
return nil, false, nil
}
if entry.NotFound {
return nil, true, ErrAPIKeyNotFound
}
if entry.Snapshot == nil {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
snapshot := &APIKeyAuthSnapshot{
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
},
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
}
}
return snapshot
}
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
if snapshot == nil {
return nil
}
apiKey := &APIKey{
ID: snapshot.APIKeyID,
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
},
}
if snapshot.Group != nil {
apiKey.Group = &Group{
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
}
}
return apiKey
}

View File

@@ -0,0 +1,48 @@
package service
import "context"
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
if key == "" {
return
}
cacheKey := s.authCacheKey(key)
s.deleteAuthCache(ctx, cacheKey)
}
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
if userID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
if groupID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
if len(keys) == 0 {
return
}
for _, key := range keys {
if key == "" {
continue
}
s.deleteAuthCache(ctx, s.authCacheKey(key))
}
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
var (
@@ -31,9 +33,11 @@ const (
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID用于删除等轻量场景
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
// GetByKeyForAuth 认证专用查询,返回最小字段集
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
}
// APIKeyCache defines cache operations for API key service
@@ -55,6 +61,17 @@ type APIKeyCache interface {
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
type APIKeyAuthCacheInvalidator interface {
InvalidateAuthCacheByKey(ctx context.Context, key string)
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}
// CreateAPIKeyRequest 创建API Key请求
@@ -83,6 +100,9 @@ type APIKeyService struct {
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -94,7 +114,7 @@ func NewAPIKeyService(
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
return &APIKeyService{
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -102,6 +122,8 @@ func NewAPIKeyService(
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
}
// GenerateKey 生成随机API Key
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("create api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
cacheKey := s.authCacheKey(key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
if s.authCfg.singleflight {
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
return s.loadAuthCacheEntry(ctx, key, cacheKey)
})
if err != nil {
return nil, err
}
entry, _ := value.(*APIKeyAuthCacheEntry)
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
} else {
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
if err != nil {
return nil, err
}
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
}
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.cache != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
apiKey.Key = key
return apiKey, nil
}
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, fmt.Errorf("update api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
return apiKey, nil
}
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
return ErrInsufficientPerms
}
// 清除Redis缓存使用 ownerID 而非 apiKey.UserID
// 清除Redis缓存使用 userID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
}
s.InvalidateAuthCacheByKey(ctx, key)
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)

View File

@@ -0,0 +1,417 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
type authRepoStub struct {
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
}
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
if s.getByKeyForAuth == nil {
panic("unexpected GetByKeyForAuth call")
}
return s.getByKeyForAuth(ctx, key)
}
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
if s.listKeysByUserID == nil {
panic("unexpected ListKeysByUserID call")
}
return s.listKeysByUserID(ctx, userID)
}
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
if s.listKeysByGroupID == nil {
panic("unexpected ListKeysByGroupID call")
}
return s.listKeysByGroupID(ctx, groupID)
}
type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
setAuthKeys []string
deleteAuthKeys []string
}
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
if s.getAuthCache == nil {
return nil, redis.Nil
}
return s.getAuthCache(ctx, key)
}
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
s.setAuthKeys = append(s.setAuthKeys, key)
return nil
}
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "g",
Platform: PlatformAnthropic,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
},
},
}
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return cacheEntry, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k1")
require.NoError(t, err)
require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return &APIKey{
ID: 5,
UserID: 7,
Status: StatusActive,
User: &User{
ID: 7,
Status: StatusActive,
Role: RoleUser,
Balance: 12,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
apiKey, err := svc.GetByKey(context.Background(), "k2")
require.NoError(t, err)
require.Equal(t, int64(5), apiKey.ID)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
return &APIKey{
ID: 21,
UserID: 3,
Status: StatusActive,
User: &User{
ID: 3,
Status: StatusActive,
Role: RoleUser,
Balance: 5,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L1Size: 1000,
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
svc.authCacheL1.Wait()
cacheKey := svc.authCacheKey("k-l1")
_, ok := svc.authCacheL1.Get(cacheKey)
require.True(t, ok)
_, err = svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return nil, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, ErrAPIKeyNotFound
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
time.Sleep(50 * time.Millisecond)
return &APIKey{
ID: 11,
UserID: 2,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 1,
Concurrency: 1,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}
errs := make([]error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
<-start
_, err := svc.GetByKey(context.Background(), "k1")
errs[idx] = err
}(i)
}
close(start)
wg.Wait()
for _, err := range errs {
require.NoError(t, err)
}
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

View File

@@ -20,13 +20,12 @@ import (
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
ownerID int64 // GetOwnerID 的返回值
ownerErr error // GetOwnerID 的错误返回值
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
}
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
if s.getByIDErr != nil {
return nil, s.getByIDErr
}
if s.apiKey != nil {
clone := *s.apiKey
return &clone, nil
}
panic("unexpected GetByID call")
}
// GetOwnerID 返回预设的所有者 ID 或错误。
// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。
func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return s.ownerID, s.ownerErr
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
if s.getByIDErr != nil {
return "", 0, s.getByIDErr
}
if s.apiKey != nil {
return s.apiKey.Key, s.apiKey.UserID, nil
}
return "", 0, ErrAPIKeyNotFound
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
// - invalidated: 记录被清除缓存的用户 ID 列表
type apiKeyCacheStub struct {
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
}
// GetCreateAttemptCount 返回 0表示用户未超过创建次数限制
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
return nil
}
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 1
// - GetKeyAndOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
require.ErrorIs(t, err, ErrInsufficientPerms)
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
require.Empty(t, cache.invalidated) // 验证缓存未被清除
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetOwnerID 返回所有者 ID 为 7
// - GetKeyAndOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetOwnerID 返回正确的所有者 ID
// - GetKeyAndOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
ownerID: 3,
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
require.ErrorContains(t, err, "delete api key")
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}

View File

@@ -0,0 +1,33 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &UsageService{authCacheInvalidator: invalidator}
svc.invalidateUsageCaches(context.Background(), 7, false)
require.Empty(t, invalidator.userIDs)
svc.invalidateUsageCaches(context.Background(), 7, true)
require.Equal(t, []int64{7}, invalidator.userIDs)
}
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &RedeemService{authCacheInvalidator: invalidator}
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
groupID := int64(3)
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
}

View File

@@ -0,0 +1,242 @@
package service
import (
"context"
"errors"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
const (
defaultDashboardAggregationTimeout = 2 * time.Minute
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
dashboardAggregationRetentionInterval = 6 * time.Hour
)
var (
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
// DashboardAggregationService 负责定时聚合与回填。
type DashboardAggregationService struct {
repo DashboardAggregationRepository
timingWheel *TimingWheelService
cfg config.DashboardAggregationConfig
running int32
lastRetentionCleanup atomic.Value // time.Time
}
// NewDashboardAggregationService 创建聚合服务。
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
var aggCfg config.DashboardAggregationConfig
if cfg != nil {
aggCfg = cfg.DashboardAgg
}
return &DashboardAggregationService{
repo: repo,
timingWheel: timingWheel,
cfg: aggCfg,
}
}
// Start 启动定时聚合作业(重启生效配置)。
func (s *DashboardAggregationService) Start() {
if s == nil || s.repo == nil || s.timingWheel == nil {
return
}
if !s.cfg.Enabled {
log.Printf("[DashboardAggregation] 聚合作业已禁用")
return
}
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
if interval <= 0 {
interval = time.Minute
}
if s.cfg.RecomputeDays > 0 {
go s.recomputeRecentDays()
}
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
s.runScheduledAggregation()
})
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
if !s.cfg.BackfillEnabled {
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
}
}
// TriggerBackfill 触发回填(异步)。
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.BackfillEnabled {
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
return ErrDashboardBackfillDisabled
}
if !end.After(start) {
return errors.New("回填时间范围无效")
}
if s.cfg.BackfillMaxDays > 0 {
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
if end.Sub(start) > maxRange {
return ErrDashboardBackfillTooLarge
}
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
defer cancel()
if err := s.backfillRange(ctx, start, end); err != nil {
log.Printf("[DashboardAggregation] 回填失败: %v", err)
}
}()
return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays
if days <= 0 {
return
}
now := time.Now().UTC()
start := now.AddDate(0, 0, -days)
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
defer cancel()
if err := s.backfillRange(ctx, start, now); err != nil {
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
return
}
}
func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return
}
defer atomic.StoreInt32(&s.running, 0)
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
defer cancel()
now := time.Now().UTC()
last, err := s.repo.GetAggregationWatermark(ctx)
if err != nil {
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
last = time.Unix(0, 0).UTC()
}
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
epoch := time.Unix(0, 0).UTC()
start := last.Add(-lookback)
if !last.After(epoch) {
retentionDays := s.cfg.Retention.UsageLogsDays
if retentionDays <= 0 {
retentionDays = 1
}
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
} else if start.After(now) {
start = now.Add(-lookback)
}
if err := s.aggregateRange(ctx, start, now); err != nil {
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
return
}
if err := s.repo.UpdateAggregationWatermark(ctx, now); err != nil {
log.Printf("[DashboardAggregation] 更新水位失败: %v", err)
}
s.maybeCleanupRetention(ctx, now)
}
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errors.New("聚合作业正在运行")
}
defer atomic.StoreInt32(&s.running, 0)
startUTC := start.UTC()
endUTC := end.UTC()
if !endUTC.After(startUTC) {
return errors.New("回填时间范围无效")
}
cursor := truncateToDayUTC(startUTC)
for cursor.Before(endUTC) {
windowEnd := cursor.Add(24 * time.Hour)
if windowEnd.After(endUTC) {
windowEnd = endUTC
}
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
return err
}
cursor = windowEnd
}
if err := s.repo.UpdateAggregationWatermark(ctx, endUTC); err != nil {
log.Printf("[DashboardAggregation] 更新水位失败: %v", err)
}
s.maybeCleanupRetention(ctx, endUTC)
return nil
}
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
if !end.After(start) {
return nil
}
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
}
return s.repo.AggregateRange(ctx, start, end)
}
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
lastAny := s.lastRetentionCleanup.Load()
if lastAny != nil {
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
return
}
}
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
}
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
if usageErr != nil {
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
if aggErr == nil && usageErr == nil {
s.lastRetentionCleanup.Store(now)
}
}
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}

View File

@@ -0,0 +1,106 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
s.aggregateCalls++
s.lastStart = start
s.lastEnd = end
return s.aggregateErr
}
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil
}
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
return nil
}
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
return s.cleanupAggregatesErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.aggregateCalls)
require.False(t, repo.lastEnd.IsZero())
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
}
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
BackfillEnabled: true,
BackfillMaxDays: 1,
},
}
start := time.Now().AddDate(0, 0, -3)
end := time.Now()
err := svc.TriggerBackfill(start, end)
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
require.Equal(t, 0, repo.aggregateCalls)
}

View File

@@ -2,25 +2,119 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
// DashboardService provides aggregated statistics for admin dashboard.
type DashboardService struct {
usageRepo UsageLogRepository
const (
defaultDashboardStatsFreshTTL = 15 * time.Second
defaultDashboardStatsCacheTTL = 30 * time.Second
defaultDashboardStatsRefreshTimeout = 30 * time.Second
)
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
// DashboardStatsCache 定义仪表盘统计缓存接口。
type DashboardStatsCache interface {
GetDashboardStats(ctx context.Context) (string, error)
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
DeleteDashboardStats(ctx context.Context) error
}
func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
type dashboardStatsRangeFetcher interface {
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
}
type dashboardStatsCacheEntry struct {
Stats *usagestats.DashboardStats `json:"stats"`
UpdatedAt int64 `json:"updated_at"`
}
// DashboardService 提供管理员仪表盘统计服务。
type DashboardService struct {
usageRepo UsageLogRepository
aggRepo DashboardAggregationRepository
cache DashboardStatsCache
cacheFreshTTL time.Duration
cacheTTL time.Duration
refreshTimeout time.Duration
refreshing int32
aggEnabled bool
aggInterval time.Duration
aggLookback time.Duration
aggUsageDays int
}
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
freshTTL := defaultDashboardStatsFreshTTL
cacheTTL := defaultDashboardStatsCacheTTL
refreshTimeout := defaultDashboardStatsRefreshTimeout
aggEnabled := true
aggInterval := time.Minute
aggLookback := 2 * time.Minute
aggUsageDays := 90
if cfg != nil {
if !cfg.Dashboard.Enabled {
cache = nil
}
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
}
if cfg.Dashboard.StatsTTLSeconds > 0 {
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
}
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
}
aggEnabled = cfg.DashboardAgg.Enabled
if cfg.DashboardAgg.IntervalSeconds > 0 {
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
}
if cfg.DashboardAgg.LookbackSeconds > 0 {
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
}
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
}
}
return &DashboardService{
usageRepo: usageRepo,
usageRepo: usageRepo,
aggRepo: aggRepo,
cache: cache,
cacheFreshTTL: freshTTL,
cacheTTL: cacheTTL,
refreshTimeout: refreshTimeout,
aggEnabled: aggEnabled,
aggInterval: aggInterval,
aggLookback: aggLookback,
aggUsageDays: aggUsageDays,
}
}
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
stats, err := s.usageRepo.GetDashboardStats(ctx)
if s.cache != nil {
cached, fresh, err := s.getCachedDashboardStats(ctx)
if err == nil && cached != nil {
s.refreshAggregationStaleness(cached)
if !fresh {
s.refreshDashboardStatsAsync()
}
return cached, nil
}
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
}
}
stats, err := s.refreshDashboardStats(ctx)
if err != nil {
return nil, fmt.Errorf("get dashboard stats: %w", err)
}
@@ -43,6 +137,169 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
return nil, false, err
}
var entry dashboardStatsCacheEntry
if err := json.Unmarshal([]byte(data), &entry); err != nil {
s.evictDashboardStatsCache(err)
return nil, false, ErrDashboardStatsCacheMiss
}
if entry.Stats == nil {
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
return nil, false, ErrDashboardStatsCacheMiss
}
age := time.Since(time.Unix(entry.UpdatedAt, 0))
return entry.Stats, age <= s.cacheFreshTTL, nil
}
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
stats, err := s.fetchDashboardStats(ctx)
if err != nil {
return nil, err
}
s.applyAggregationStatus(ctx, stats)
cacheCtx, cancel := s.cacheOperationContext()
defer cancel()
s.saveDashboardStatsCache(cacheCtx, stats)
return stats, nil
}
func (s *DashboardService) refreshDashboardStatsAsync() {
if s.cache == nil {
return
}
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
return
}
go func() {
defer atomic.StoreInt32(&s.refreshing, 0)
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
defer cancel()
stats, err := s.fetchDashboardStats(ctx)
if err != nil {
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
return
}
s.applyAggregationStatus(ctx, stats)
cacheCtx, cancel := s.cacheOperationContext()
defer cancel()
s.saveDashboardStatsCache(cacheCtx, stats)
}()
}
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
if !s.aggEnabled {
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
now := time.Now().UTC()
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
}
}
return s.usageRepo.GetDashboardStats(ctx)
}
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
if s.cache == nil || stats == nil {
return
}
entry := dashboardStatsCacheEntry{
Stats: stats,
UpdatedAt: time.Now().Unix(),
}
data, err := json.Marshal(entry)
if err != nil {
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
return
}
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
}
}
func (s *DashboardService) evictDashboardStatsCache(reason error) {
if s.cache == nil {
return
}
cacheCtx, cancel := s.cacheOperationContext()
defer cancel()
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
}
if reason != nil {
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
}
}
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), s.refreshTimeout)
}
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
if stats == nil {
return
}
updatedAt := s.fetchAggregationUpdatedAt(ctx)
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
if stats == nil {
return
}
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
if s.aggRepo == nil {
return time.Unix(0, 0).UTC()
}
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
if err != nil {
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
return time.Unix(0, 0).UTC()
}
if updatedAt.IsZero() {
return time.Unix(0, 0).UTC()
}
return updatedAt.UTC()
}
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
if !s.aggEnabled {
return true
}
epoch := time.Unix(0, 0).UTC()
if !updatedAt.After(epoch) {
return true
}
threshold := s.aggInterval + s.aggLookback
return now.Sub(updatedAt) > threshold
}
func parseStatsUpdatedAt(raw string) time.Time {
if raw == "" {
return time.Unix(0, 0).UTC()
}
parsed, err := time.Parse(time.RFC3339, raw)
if err != nil {
return time.Unix(0, 0).UTC()
}
return parsed.UTC()
}
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {

View File

@@ -0,0 +1,387 @@
package service
import (
"context"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
type usageRepoStub struct {
UsageLogRepository
stats *usagestats.DashboardStats
rangeStats *usagestats.DashboardStats
err error
rangeErr error
calls int32
rangeCalls int32
rangeStart time.Time
rangeEnd time.Time
onCall chan struct{}
}
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
atomic.AddInt32(&s.calls, 1)
if s.onCall != nil {
select {
case s.onCall <- struct{}{}:
default:
}
}
if s.err != nil {
return nil, s.err
}
return s.stats, nil
}
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
atomic.AddInt32(&s.rangeCalls, 1)
s.rangeStart = start
s.rangeEnd = end
if s.rangeErr != nil {
return nil, s.rangeErr
}
if s.rangeStats != nil {
return s.rangeStats, nil
}
return s.stats, nil
}
type dashboardCacheStub struct {
get func(ctx context.Context) (string, error)
set func(ctx context.Context, data string, ttl time.Duration) error
del func(ctx context.Context) error
getCalls int32
setCalls int32
delCalls int32
lastSetMu sync.Mutex
lastSet string
}
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
atomic.AddInt32(&c.getCalls, 1)
if c.get != nil {
return c.get(ctx)
}
return "", ErrDashboardStatsCacheMiss
}
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
atomic.AddInt32(&c.setCalls, 1)
c.lastSetMu.Lock()
c.lastSet = data
c.lastSetMu.Unlock()
if c.set != nil {
return c.set(ctx, data, ttl)
}
return nil
}
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
atomic.AddInt32(&c.delCalls, 1)
if c.del != nil {
return c.del(ctx)
}
return nil
}
type dashboardAggregationRepoStub struct {
watermark time.Time
err error
}
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil {
return time.Time{}, s.err
}
return s.watermark, nil
}
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
t.Helper()
c.lastSetMu.Lock()
data := c.lastSet
c.lastSetMu.Unlock()
var entry dashboardStatsCacheEntry
err := json.Unmarshal([]byte(data), &entry)
require.NoError(t, err)
return entry
}
func TestDashboardService_CacheHitFresh(t *testing.T) {
stats := &usagestats.DashboardStats{
TotalUsers: 10,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
entry := dashboardStatsCacheEntry{
Stats: stats,
UpdatedAt: time.Now().Unix(),
}
payload, err := json.Marshal(entry)
require.NoError(t, err)
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return string(payload), nil
},
}
repo := &usageRepoStub{
stats: &usagestats.DashboardStats{TotalUsers: 99},
}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
stats := &usagestats.DashboardStats{
TotalUsers: 7,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "", ErrDashboardStatsCacheMiss
},
}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
entry := cache.readLastEntry(t)
require.Equal(t, stats, entry.Stats)
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
}
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
stats := &usagestats.DashboardStats{
TotalUsers: 3,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "", nil
},
}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
staleStats := &usagestats.DashboardStats{
TotalUsers: 11,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
entry := dashboardStatsCacheEntry{
Stats: staleStats,
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
}
payload, err := json.Marshal(entry)
require.NoError(t, err)
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return string(payload), nil
},
}
refreshCh := make(chan struct{}, 1)
repo := &usageRepoStub{
stats: &usagestats.DashboardStats{TotalUsers: 22},
onCall: refreshCh,
}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, staleStats, got)
select {
case <-refreshCh:
case <-time.After(1 * time.Second):
t.Fatal("等待异步刷新超时")
}
require.Eventually(t, func() bool {
return atomic.LoadInt32(&cache.setCalls) >= 1
}, 1*time.Second, 10*time.Millisecond)
}
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "not-json", nil
},
}
stats := &usagestats.DashboardStats{TotalUsers: 9}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
}
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "not-json", nil
},
}
repo := &usageRepoStub{err: errors.New("db down")}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
_, err := svc.GetDashboardStats(context.Background())
require.Error(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
}
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
stats := &usagestats.DashboardStats{}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
svc := NewDashboardService(repo, aggRepo, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
require.True(t, got.StatsStale)
}
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
aggNow := time.Now().UTC().Truncate(time.Second)
stats := &usagestats.DashboardStats{}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
},
}
svc := NewDashboardService(repo, aggRepo, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
require.False(t, got.StatsStale)
}
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
expected := &usagestats.DashboardStats{TotalUsers: 42}
repo := &usageRepoStub{
rangeStats: expected,
err: errors.New("should not call aggregated stats"),
}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: false,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 7,
},
},
}
svc := NewDashboardService(repo, nil, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, int64(42), got.TotalUsers)
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
require.False(t, repo.rangeEnd.IsZero())
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
}

View File

@@ -50,13 +50,15 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type GroupService struct {
groupRepo GroupRepository
groupRepo GroupRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo GroupRepository) *GroupService {
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
return &GroupService{
groupRepo: groupRepo,
groupRepo: groupRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
return group, nil
}
@@ -167,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
return fmt.Errorf("get group: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
}
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}

View File

@@ -0,0 +1,404 @@
package service
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
const (
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
codexCacheTTL = 15 * time.Minute
)
var codexModelMap = map[string]string{
"gpt-5.1-codex": "gpt-5.1-codex",
"gpt-5.1-codex-low": "gpt-5.1-codex",
"gpt-5.1-codex-medium": "gpt-5.1-codex",
"gpt-5.1-codex-high": "gpt-5.1-codex",
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
"gpt-5.2": "gpt-5.2",
"gpt-5.2-none": "gpt-5.2",
"gpt-5.2-low": "gpt-5.2",
"gpt-5.2-medium": "gpt-5.2",
"gpt-5.2-high": "gpt-5.2",
"gpt-5.2-xhigh": "gpt-5.2",
"gpt-5.2-codex": "gpt-5.2-codex",
"gpt-5.2-codex-low": "gpt-5.2-codex",
"gpt-5.2-codex-medium": "gpt-5.2-codex",
"gpt-5.2-codex-high": "gpt-5.2-codex",
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5.1": "gpt-5.1",
"gpt-5.1-none": "gpt-5.1",
"gpt-5.1-low": "gpt-5.1",
"gpt-5.1-medium": "gpt-5.1",
"gpt-5.1-high": "gpt-5.1",
"gpt-5.1-chat-latest": "gpt-5.1",
"gpt-5-codex": "gpt-5.1-codex",
"codex-mini-latest": "gpt-5.1-codex-mini",
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
"gpt-5": "gpt-5.1",
"gpt-5-mini": "gpt-5.1",
"gpt-5-nano": "gpt-5.1",
}
type codexTransformResult struct {
Modified bool
NormalizedModel string
PromptCacheKey string
}
type opencodeCacheMetadata struct {
ETag string `json:"etag"`
LastFetch string `json:"lastFetch,omitempty"`
LastChecked int64 `json:"lastChecked"`
}
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
result := codexTransformResult{}
model := ""
if v, ok := reqBody["model"].(string); ok {
model = v
}
normalizedModel := normalizeCodexModel(model)
if normalizedModel != "" {
if model != normalizedModel {
reqBody["model"] = normalizedModel
result.Modified = true
}
result.NormalizedModel = normalizedModel
}
if v, ok := reqBody["store"].(bool); !ok || v {
reqBody["store"] = false
result.Modified = true
}
if v, ok := reqBody["stream"].(bool); !ok || !v {
reqBody["stream"] = true
result.Modified = true
}
if _, ok := reqBody["max_output_tokens"]; ok {
delete(reqBody, "max_output_tokens")
result.Modified = true
}
if _, ok := reqBody["max_completion_tokens"]; ok {
delete(reqBody, "max_completion_tokens")
result.Modified = true
}
if normalizeCodexTools(reqBody) {
result.Modified = true
}
if v, ok := reqBody["prompt_cache_key"].(string); ok {
result.PromptCacheKey = strings.TrimSpace(v)
}
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
existingInstructions, _ := reqBody["instructions"].(string)
existingInstructions = strings.TrimSpace(existingInstructions)
if instructions != "" {
if existingInstructions != instructions {
reqBody["instructions"] = instructions
result.Modified = true
}
}
if input, ok := reqBody["input"].([]any); ok {
input = filterCodexInput(input)
reqBody["input"] = input
result.Modified = true
}
return result
}
func normalizeCodexModel(model string) string {
if model == "" {
return "gpt-5.1"
}
modelID := model
if strings.Contains(modelID, "/") {
parts := strings.Split(modelID, "/")
modelID = parts[len(parts)-1]
}
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
return mapped
}
normalized := strings.ToLower(modelID)
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
return "gpt-5.2-codex"
}
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
return "gpt-5.2"
}
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
return "gpt-5.1-codex-max"
}
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
return "gpt-5.1-codex-mini"
}
if strings.Contains(normalized, "codex-mini-latest") ||
strings.Contains(normalized, "gpt-5-codex-mini") ||
strings.Contains(normalized, "gpt 5 codex mini") {
return "codex-mini-latest"
}
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
return "gpt-5.1-codex"
}
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
return "gpt-5.1"
}
if strings.Contains(normalized, "codex") {
return "gpt-5.1-codex"
}
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
return "gpt-5.1"
}
return "gpt-5.1"
}
func getNormalizedCodexModel(modelID string) string {
if modelID == "" {
return ""
}
if mapped, ok := codexModelMap[modelID]; ok {
return mapped
}
lower := strings.ToLower(modelID)
for key, value := range codexModelMap {
if strings.ToLower(key) == lower {
return value
}
}
return ""
}
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
cacheDir := codexCachePath("")
if cacheDir == "" {
return ""
}
cacheFile := filepath.Join(cacheDir, cacheFileName)
metaFile := filepath.Join(cacheDir, metaFileName)
var cachedContent string
if content, ok := readFile(cacheFile); ok {
cachedContent = content
}
var meta opencodeCacheMetadata
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
return cachedContent
}
}
content, etag, status, err := fetchWithETag(url, meta.ETag)
if err == nil && status == http.StatusNotModified && cachedContent != "" {
return cachedContent
}
if err == nil && status >= 200 && status < 300 && content != "" {
_ = writeFile(cacheFile, content)
meta = opencodeCacheMetadata{
ETag: etag,
LastFetch: time.Now().UTC().Format(time.RFC3339),
LastChecked: time.Now().UnixMilli(),
}
_ = writeJSON(metaFile, meta)
return content
}
return cachedContent
}
func getOpenCodeCodexHeader() string {
return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
}
func GetOpenCodeInstructions() string {
return getOpenCodeCodexHeader()
}
func filterCodexInput(input []any) []any {
filtered := make([]any, 0, len(input))
for _, item := range input {
m, ok := item.(map[string]any)
if !ok {
filtered = append(filtered, item)
continue
}
if typ, ok := m["type"].(string); ok && typ == "item_reference" {
continue
}
delete(m, "id")
filtered = append(filtered, m)
}
return filtered
}
func normalizeCodexTools(reqBody map[string]any) bool {
rawTools, ok := reqBody["tools"]
if !ok || rawTools == nil {
return false
}
tools, ok := rawTools.([]any)
if !ok {
return false
}
modified := false
for idx, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
continue
}
toolType, _ := toolMap["type"].(string)
if strings.TrimSpace(toolType) != "function" {
continue
}
function, ok := toolMap["function"].(map[string]any)
if !ok {
continue
}
if _, ok := toolMap["name"]; !ok {
if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" {
toolMap["name"] = name
modified = true
}
}
if _, ok := toolMap["description"]; !ok {
if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" {
toolMap["description"] = desc
modified = true
}
}
if _, ok := toolMap["parameters"]; !ok {
if params, ok := function["parameters"]; ok {
toolMap["parameters"] = params
modified = true
}
}
if _, ok := toolMap["strict"]; !ok {
if strict, ok := function["strict"]; ok {
toolMap["strict"] = strict
modified = true
}
}
tools[idx] = toolMap
}
if modified {
reqBody["tools"] = tools
}
return modified
}
func codexCachePath(filename string) string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
cacheDir := filepath.Join(home, ".opencode", "cache")
if filename == "" {
return cacheDir
}
return filepath.Join(cacheDir, filename)
}
func readFile(path string) (string, bool) {
if path == "" {
return "", false
}
data, err := os.ReadFile(path)
if err != nil {
return "", false
}
return string(data), true
}
func writeFile(path, content string) error {
if path == "" {
return fmt.Errorf("empty cache path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
return os.WriteFile(path, []byte(content), 0o644)
}
func loadJSON(path string, target any) bool {
data, err := os.ReadFile(path)
if err != nil {
return false
}
if err := json.Unmarshal(data, target); err != nil {
return false
}
return true
}
func writeJSON(path string, value any) error {
if path == "" {
return fmt.Errorf("empty json path")
}
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
data, err := json.Marshal(value)
if err != nil {
return err
}
return os.WriteFile(path, data, 0o644)
}
func fetchWithETag(url, etag string) (string, string, int, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", "", 0, err
}
req.Header.Set("User-Agent", "sub2api-codex")
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", 0, err
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", resp.StatusCode, err
}
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}

View File

@@ -12,6 +12,7 @@ import (
"io"
"log"
"net/http"
"os"
"regexp"
"sort"
"strconv"
@@ -20,6 +21,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/gin-gonic/gin"
@@ -528,33 +530,38 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
// Apply model mapping
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
reqBody["model"] = mappedModel
bodyModified = true
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
// Apply model mapping (skip for Codex CLI for transparent forwarding)
mappedModel := reqModel
if !isCodexCLI {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
reqBody["model"] = mappedModel
bodyModified = true
}
}
// For OAuth accounts using ChatGPT internal API:
// 1. Add store: false
// 2. Normalize input format for Codex API compatibility
if account.Type == AccountTypeOAuth {
reqBody["store"] = false
// Codex 上游不接受 max_output_tokens 参数,需要在转发前移除。
delete(reqBody, "max_output_tokens")
bodyModified = true
// Normalize input format: convert AI SDK multi-part content format to simplified format
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
// Codex API expects: {"content": "..."}
if normalizeInputForCodexAPI(reqBody) {
if account.Type == AccountTypeOAuth && !isCodexCLI {
codexResult := applyCodexOAuthTransform(reqBody)
if codexResult.Modified {
bodyModified = true
}
if codexResult.NormalizedModel != "" {
mappedModel = codexResult.NormalizedModel
}
if codexResult.PromptCacheKey != "" {
promptCacheKey = codexResult.PromptCacheKey
}
}
// Re-serialize body only if modified
@@ -573,7 +580,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Build upstream request
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
if err != nil {
return nil, err
}
@@ -674,7 +681,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}, nil
}
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
// Determine target URL based on account type
var targetURL string
switch account.Type {
@@ -714,12 +721,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if chatgptAccountID != "" {
req.Header.Set("chatgpt-account-id", chatgptAccountID)
}
// Set accept header based on stream mode
if isStream {
req.Header.Set("accept", "text/event-stream")
} else {
req.Header.Set("accept", "application/json")
}
}
// Whitelist passthrough headers
@@ -731,6 +732,22 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
}
}
}
if account.Type == AccountTypeOAuth {
req.Header.Set("OpenAI-Beta", "responses=experimental")
if isCodexCLI {
req.Header.Set("originator", "codex_cli_rs")
} else {
req.Header.Set("originator", "opencode")
}
req.Header.Set("accept", "text/event-stream")
if promptCacheKey != "" {
req.Header.Set("conversation_id", promptCacheKey)
req.Header.Set("session_id", promptCacheKey)
} else {
req.Header.Del("conversation_id")
req.Header.Del("session_id")
}
}
// Apply custom User-Agent if configured
customUA := account.GetOpenAIUserAgent()
@@ -1109,6 +1126,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return nil, err
}
if account.Type == AccountTypeOAuth {
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
}
}
// Parse usage
var response struct {
Usage struct {
@@ -1148,6 +1172,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return usage, nil
}
func isEventStreamResponse(header http.Header) bool {
contentType := strings.ToLower(header.Get("Content-Type"))
return strings.Contains(contentType, "text/event-stream")
}
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
usage := &OpenAIUsage{}
if ok {
var response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
}
if err := json.Unmarshal(finalResponse, &response); err == nil {
usage.InputTokens = response.Usage.InputTokens
usage.OutputTokens = response.Usage.OutputTokens
usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
}
body = finalResponse
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
} else {
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
}
body = []byte(bodyText)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json; charset=utf-8"
if !ok {
contentType = resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream"
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
var event struct {
Type string `json:"type"`
Response json.RawMessage `json:"response"`
}
if json.Unmarshal([]byte(data), &event) != nil {
continue
}
if event.Type == "response.done" || event.Type == "response.completed" {
if len(event.Response) > 0 {
return event.Response, true
}
}
}
return nil, false
}
func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
s.parseSSEUsage(data, usage)
}
return usage
}
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
lines := strings.Split(body, "\n")
for i, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
continue
}
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
}
return strings.Join(lines, "\n")
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
@@ -1187,101 +1315,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
return newBody
}
// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format
// that the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
// {"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
// {"content": "hello"}
//
// This function modifies reqBody in-place and returns true if any modification was made.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
input, ok := reqBody["input"]
if !ok {
return false
}
// Handle case where input is a simple string (already compatible)
if _, isString := input.(string); isString {
return false
}
// Handle case where input is an array of messages
inputArray, ok := input.([]any)
if !ok {
return false
}
modified := false
for _, item := range inputArray {
message, ok := item.(map[string]any)
if !ok {
continue
}
content, ok := message["content"]
if !ok {
continue
}
// If content is already a string, no conversion needed
if _, isString := content.(string); isString {
continue
}
// If content is an array (AI SDK format), convert to string
contentArray, ok := content.([]any)
if !ok {
continue
}
// Extract text from content array
var textParts []string
for _, part := range contentArray {
partMap, ok := part.(map[string]any)
if !ok {
continue
}
// Handle different content types
partType, _ := partMap["type"].(string)
switch partType {
case "input_text", "text":
// Extract text from input_text or text type
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
case "input_image", "image":
// For images, we need to preserve the original format
// as ChatGPT Codex API may support images in a different way
// For now, skip image parts (they will be lost in conversion)
// TODO: Consider preserving image data or handling it separately
continue
case "input_file", "file":
// Similar to images, file inputs may need special handling
continue
default:
// For unknown types, try to extract text if available
if text, ok := partMap["text"].(string); ok {
textParts = append(textParts, text)
}
}
}
// Convert content array to string
if len(textParts) > 0 {
message["content"] = strings.Join(textParts, "\n")
modified = true
}
}
return modified
}
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult

View File

@@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
Credentials: map[string]any{"base_url": "://invalid-url"},
}
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
if err == nil {
t.Fatalf("expected error for invalid base_url when allowlist disabled")
}

View File

@@ -24,10 +24,11 @@ var (
// PromoService 优惠码服务
type PromoService struct {
promoRepo PromoCodeRepository
userRepo UserRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
promoRepo PromoCodeRepository
userRepo UserRepository
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewPromoService 创建优惠码服务实例
@@ -36,12 +37,14 @@ func NewPromoService(
userRepo UserRepository,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *PromoService {
return &PromoService{
promoRepo: promoRepo,
userRepo: userRepo,
billingCacheService: billingCacheService,
entClient: entClient,
promoRepo: promoRepo,
userRepo: userRepo,
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
return fmt.Errorf("commit transaction: %w", err)
}
s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
@@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
return nil
}
func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
if bonusAmount == 0 || s.authCacheInvalidator == nil {
return
}
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
// GenerateRandomCode 生成随机优惠码
func (s *PromoService) GenerateRandomCode() (string, error) {
bytes := make([]byte, 8)

View File

@@ -0,0 +1,122 @@
# Codex Running in OpenCode
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
## CRITICAL: Tool Replacements
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan, read_plan, readPlan
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
## Available OpenCode Tools
**File Operations:**
- `write` - Create new files
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
- `edit` - Modify existing files (REPLACES apply_patch)
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
- `read` - Read file contents
**Search/Discovery:**
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
- `list` - List directories (requires absolute paths)
**Execution:**
- `bash` - Run shell commands
- No workdir parameter; do not include it in tool calls.
- Always include a short description for the command.
- Do not use cd; use absolute paths in commands.
- Quote paths containing spaces with double quotes.
- Chain multiple commands with ';' or '&&'; avoid newlines.
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
- For deletions (rm), verify by listing parent dir with `list`.
**Network:**
- `webfetch` - Fetch web content
- Use fully-formed URLs (http/https; http auto-upgrades to https).
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
- Read-only; short cache window.
**Task Management:**
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
- `todoread` - Read current plan
## Substitution Rules
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
**Path Usage:** Use per-tool conventions to avoid conflicts:
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
## Verification Checklist
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I following each tool's path requirements?
If ANY answer is NO → STOP and correct before proceeding.
## OpenCode Working Style
**Communication:**
- Send brief preambles (8-12 words) before tool calls, building on prior context
- Provide progress updates during longer tasks
**Execution:**
- Keep working autonomously until query is fully resolved before yielding
- Don't return to user with partial solutions
**Code Approach:**
- New projects: Be ambitious and creative
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
**Testing:**
- If tests exist: Start specific to your changes, then broader validation
## Advanced Tools
**Task Tool (Sub-Agents):**
- Use the Task tool (functions.task) to launch sub-agents
- Check the Task tool description for current agent types and their capabilities
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
- The agent list is dynamically generated - refer to tool schema for available agents
**Parallelization:**
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
- Reserve sequential calls for ordered or data-dependent steps.
**MCP Tools:**
- Model Context Protocol servers provide additional capabilities
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
- Check your available tools for MCP integrations
- Use when the tool's functionality matches your task needs
## What Remains from Codex
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
## Approvals & Safety
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).

View File

@@ -0,0 +1,63 @@
<user_instructions priority="0">
<environment_override priority="0">
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
</environment_override>
<tool_replacements priority="0">
<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>
<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan
- ALWAYS use: todowrite for ALL task/plan operations
- Use todoread to read current plan
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>
</tool_replacements>
<available_tools priority="0">
File Operations:
• write - Create new files
• edit - Modify existing files (REPLACES apply_patch)
• patch - Apply diff patches
• read - Read file contents
Search/Discovery:
• grep - Search file contents
• glob - Find files by pattern
• list - List directories (use relative paths)
Execution:
• bash - Run shell commands
Network:
• webfetch - Fetch web content
Task Management:
• todowrite - Manage tasks/plans (REPLACES update_plan)
• todoread - Read current plan
</available_tools>
<substitution_rules priority="0">
Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread
absolute paths → relative paths
</substitution_rules>
<verification_checklist priority="0">
Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I using relative paths?
If ANY answer is NO → STOP and correct before proceeding.
</verification_checklist>
</user_instructions>

View File

@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo RedeemCodeRepository
userRepo UserRepository
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
redeemRepo RedeemCodeRepository
userRepo UserRepository
subscriptionService *SubscriptionService
cache RedeemCache
billingCacheService *BillingCacheService
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewRedeemService 创建兑换码服务实例
@@ -84,14 +85,16 @@ func NewRedeemService(
cache RedeemCache,
billingCacheService *BillingCacheService,
entClient *dbent.Client,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
cache: cache,
billingCacheService: billingCacheService,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
// invalidateRedeemCaches 失效兑换相关的缓存
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
if s.billingCacheService == nil {
return
}
switch redeemCode.Type {
case RedeemTypeBalance:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
case RedeemTypeConcurrency:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
case RedeemTypeSubscription:
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if s.billingCacheService == nil {
return
}
if redeemCode.GroupID != nil {
groupID := *redeemCode.GroupID
go func() {

View File

@@ -54,17 +54,19 @@ type UsageStats struct {
// UsageService 使用统计服务
type UsageService struct {
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
// 扣除用户余额
balanceUpdated := false
if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
balanceUpdated = true
}
if tx != nil {
@@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
}
}
s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated)
return usageLog, nil
}
func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) {
if !balanceUpdated || s.authCacheInvalidator == nil {
return
}
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)

View File

@@ -55,13 +55,15 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type UserService struct {
userRepo UserRepository
userRepo UserRepository
authCacheInvalidator APIKeyAuthCacheInvalidator
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo UserRepository) *UserService {
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
return &UserService{
userRepo: userRepo,
userRepo: userRepo,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
oldConcurrency := user.Concurrency
// 更新字段
if req.Email != nil {
@@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return user, nil
}
@@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
return fmt.Errorf("update balance: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu
if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
return fmt.Errorf("update concurrency: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
@@ -192,12 +204,18 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
return nil
}
// Delete 删除用户(管理员功能)
func (s *UserService) Delete(ctx context.Context, userID int64) error {
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}

View File

@@ -49,6 +49,13 @@ func ProvideTokenRefreshService(
return svc
}
// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
svc := NewDashboardAggregationService(repo, timingWheel, cfg)
svc.Start()
return svc
}
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute)
@@ -145,12 +152,18 @@ func ProvideOpsScheduledReportService(
return svc
}
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
return apiKeyService
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewAPIKeyService,
ProvideAPIKeyAuthCacheInvalidator,
NewGroupService,
NewAccountService,
NewProxyService,
@@ -194,6 +207,7 @@ var ProviderSet = wire.NewSet(
ProvideTokenRefreshService,
ProvideAccountExpiryService,
ProvideTimingWheelService,
ProvideDashboardAggregationService,
ProvideDeferredService,
NewAntigravityQuotaFetcher,
NewUserAttributeService,