merge: bring in the latest changes from main

Conflicts resolved:
- backend/internal/config/config.go: merged the Ops and Dashboard configuration
- backend/internal/server/api_contract_test.go: merged handler initialization
- backend/internal/service/openai_gateway_service.go: kept the Ops error-tracking logic
- backend/internal/service/wire.go: merged the Ops and APIKeyAuth providers

Main changes merged in:
- Dashboard caching and pre-aggregation
- API Key auth-cache optimization
- Codex conversion support
- partitioned usage-log tables
@@ -186,9 +186,11 @@ type BulkUpdateAccountResult struct {
 
 // BulkUpdateAccountsResult is the aggregated response for bulk updates.
 type BulkUpdateAccountsResult struct {
-	Success int                       `json:"success"`
-	Failed  int                       `json:"failed"`
-	Results []BulkUpdateAccountResult `json:"results"`
+	Success    int                       `json:"success"`
+	Failed     int                       `json:"failed"`
+	SuccessIDs []int64                   `json:"success_ids"`
+	FailedIDs  []int64                   `json:"failed_ids"`
+	Results    []BulkUpdateAccountResult `json:"results"`
 }
 
 type CreateProxyInput struct {
@@ -244,14 +246,15 @@ type ProxyExitInfoProber interface {
 
 // adminServiceImpl implements AdminService
 type adminServiceImpl struct {
-	userRepo            UserRepository
-	groupRepo           GroupRepository
-	accountRepo         AccountRepository
-	proxyRepo           ProxyRepository
-	apiKeyRepo          APIKeyRepository
-	redeemCodeRepo      RedeemCodeRepository
-	billingCacheService *BillingCacheService
-	proxyProber         ProxyExitInfoProber
+	userRepo             UserRepository
+	groupRepo            GroupRepository
+	accountRepo          AccountRepository
+	proxyRepo            ProxyRepository
+	apiKeyRepo           APIKeyRepository
+	redeemCodeRepo       RedeemCodeRepository
+	billingCacheService  *BillingCacheService
+	proxyProber          ProxyExitInfoProber
+	authCacheInvalidator APIKeyAuthCacheInvalidator
 }
 
 // NewAdminService creates a new AdminService
@@ -264,16 +267,18 @@ func NewAdminService(
 	redeemCodeRepo RedeemCodeRepository,
 	billingCacheService *BillingCacheService,
 	proxyProber ProxyExitInfoProber,
+	authCacheInvalidator APIKeyAuthCacheInvalidator,
 ) AdminService {
 	return &adminServiceImpl{
-		userRepo:            userRepo,
-		groupRepo:           groupRepo,
-		accountRepo:         accountRepo,
-		proxyRepo:           proxyRepo,
-		apiKeyRepo:          apiKeyRepo,
-		redeemCodeRepo:      redeemCodeRepo,
-		billingCacheService: billingCacheService,
-		proxyProber:         proxyProber,
+		userRepo:             userRepo,
+		groupRepo:            groupRepo,
+		accountRepo:          accountRepo,
+		proxyRepo:            proxyRepo,
+		apiKeyRepo:           apiKeyRepo,
+		redeemCodeRepo:       redeemCodeRepo,
+		billingCacheService:  billingCacheService,
+		proxyProber:          proxyProber,
+		authCacheInvalidator: authCacheInvalidator,
 	}
 }
 
@@ -323,6 +328,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
 	}
 
 	oldConcurrency := user.Concurrency
+	oldStatus := user.Status
+	oldRole := user.Role
 
 	if input.Email != "" {
 		user.Email = input.Email
@@ -355,6 +362,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
 	if err := s.userRepo.Update(ctx, user); err != nil {
 		return nil, err
 	}
+	if s.authCacheInvalidator != nil {
+		if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
+			s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
+		}
+	}
 
 	concurrencyDiff := user.Concurrency - oldConcurrency
 	if concurrencyDiff != 0 {
@@ -393,6 +405,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
 		log.Printf("delete user failed: user_id=%d err=%v", id, err)
 		return err
 	}
+	if s.authCacheInvalidator != nil {
+		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id)
+	}
 	return nil
 }
 
@@ -420,6 +435,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
 	if err := s.userRepo.Update(ctx, user); err != nil {
 		return nil, err
 	}
+	balanceDiff := user.Balance - oldBalance
+	if s.authCacheInvalidator != nil && balanceDiff != 0 {
+		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
+	}
 
 	if s.billingCacheService != nil {
 		go func() {
@@ -431,7 +450,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
 		}()
 	}
 
-	balanceDiff := user.Balance - oldBalance
 	if balanceDiff != 0 {
 		code, err := GenerateRedeemCode()
 		if err != nil {
@@ -675,10 +693,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
 	if err := s.groupRepo.Update(ctx, group); err != nil {
 		return nil, err
 	}
+	if s.authCacheInvalidator != nil {
+		s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
+	}
 	return group, nil
 }
 
 func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
+	var groupKeys []string
+	if s.authCacheInvalidator != nil {
+		keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id)
+		if err == nil {
+			groupKeys = keys
+		}
+	}
+
 	affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id)
 	if err != nil {
 		return err
@@ -697,6 +726,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
 			}
 		}()
 	}
+	if s.authCacheInvalidator != nil {
+		for _, key := range groupKeys {
+			s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key)
+		}
+	}
 
 	return nil
 }
@@ -885,7 +919,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
 // It merges credentials/extra keys instead of overwriting the whole object.
 func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
 	result := &BulkUpdateAccountsResult{
-		Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
+		SuccessIDs: make([]int64, 0, len(input.AccountIDs)),
+		FailedIDs:  make([]int64, 0, len(input.AccountIDs)),
+		Results:    make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)),
 	}
 
 	if len(input.AccountIDs) == 0 {
@@ -949,6 +985,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 			entry.Success = false
 			entry.Error = err.Error()
 			result.Failed++
+			result.FailedIDs = append(result.FailedIDs, accountID)
 			result.Results = append(result.Results, entry)
 			continue
 		}
@@ -958,6 +995,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 			entry.Success = false
 			entry.Error = err.Error()
 			result.Failed++
+			result.FailedIDs = append(result.FailedIDs, accountID)
 			result.Results = append(result.Results, entry)
 			continue
 		}
@@ -967,6 +1005,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 			entry.Success = false
 			entry.Error = err.Error()
 			result.Failed++
+			result.FailedIDs = append(result.FailedIDs, accountID)
 			result.Results = append(result.Results, entry)
 			continue
 		}
@@ -974,6 +1013,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
 
 		entry.Success = true
 		result.Success++
+		result.SuccessIDs = append(result.SuccessIDs, accountID)
 		result.Results = append(result.Results, entry)
 	}
backend/internal/service/admin_service_bulk_update_test.go (new file, 80 lines added)
@@ -0,0 +1,80 @@
//go:build unit

package service

import (
	"context"
	"errors"
	"testing"

	"github.com/stretchr/testify/require"
)

type accountRepoStubForBulkUpdate struct {
	accountRepoStub
	bulkUpdateErr    error
	bulkUpdateIDs    []int64
	bindGroupErrByID map[int64]error
}

func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
	s.bulkUpdateIDs = append([]int64{}, ids...)
	if s.bulkUpdateErr != nil {
		return 0, s.bulkUpdateErr
	}
	return int64(len(ids)), nil
}

func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
	if err, ok := s.bindGroupErrByID[accountID]; ok {
		return err
	}
	return nil
}

// TestAdminService_BulkUpdateAccounts_AllSuccessIDs verifies that a fully successful bulk update returns success_ids/failed_ids.
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
	repo := &accountRepoStubForBulkUpdate{}
	svc := &adminServiceImpl{accountRepo: repo}

	schedulable := true
	input := &BulkUpdateAccountsInput{
		AccountIDs:  []int64{1, 2, 3},
		Schedulable: &schedulable,
	}

	result, err := svc.BulkUpdateAccounts(context.Background(), input)
	require.NoError(t, err)
	require.Equal(t, 3, result.Success)
	require.Equal(t, 0, result.Failed)
	require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
	require.Empty(t, result.FailedIDs)
	require.Len(t, result.Results, 3)
}

// TestAdminService_BulkUpdateAccounts_PartialFailureIDs verifies that success_ids/failed_ids are correct when some accounts fail.
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
	repo := &accountRepoStubForBulkUpdate{
		bindGroupErrByID: map[int64]error{
			2: errors.New("bind failed"),
		},
	}
	svc := &adminServiceImpl{accountRepo: repo}

	groupIDs := []int64{10}
	schedulable := false
	input := &BulkUpdateAccountsInput{
		AccountIDs:            []int64{1, 2, 3},
		GroupIDs:              &groupIDs,
		Schedulable:           &schedulable,
		SkipMixedChannelCheck: true,
	}

	result, err := svc.BulkUpdateAccounts(context.Background(), input)
	require.NoError(t, err)
	require.Equal(t, 2, result.Success)
	require.Equal(t, 1, result.Failed)
	require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
	require.ElementsMatch(t, []int64{2}, result.FailedIDs)
	require.Len(t, result.Results, 3)
}
@@ -0,0 +1,97 @@
//go:build unit

package service

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

type balanceUserRepoStub struct {
	*userRepoStub
	updateErr error
	updated   []*User
}

func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
	if s.updateErr != nil {
		return s.updateErr
	}
	if user == nil {
		return nil
	}
	clone := *user
	s.updated = append(s.updated, &clone)
	if s.userRepoStub != nil {
		s.userRepoStub.user = &clone
	}
	return nil
}

type balanceRedeemRepoStub struct {
	*redeemRepoStub
	created []*RedeemCode
}

func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
	if code == nil {
		return nil
	}
	clone := *code
	s.created = append(s.created, &clone)
	return nil
}

type authCacheInvalidatorStub struct {
	userIDs  []int64
	groupIDs []int64
	keys     []string
}

func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
	s.keys = append(s.keys, key)
}

func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
	s.userIDs = append(s.userIDs, userID)
}

func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
	s.groupIDs = append(s.groupIDs, groupID)
}

func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
	baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
	repo := &balanceUserRepoStub{userRepoStub: baseRepo}
	redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
	invalidator := &authCacheInvalidatorStub{}
	svc := &adminServiceImpl{
		userRepo:             repo,
		redeemCodeRepo:       redeemRepo,
		authCacheInvalidator: invalidator,
	}

	_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
	require.NoError(t, err)
	require.Equal(t, []int64{7}, invalidator.userIDs)
	require.Len(t, redeemRepo.created, 1)
}

func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
	baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
	repo := &balanceUserRepoStub{userRepoStub: baseRepo}
	redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
	invalidator := &authCacheInvalidatorStub{}
	svc := &adminServiceImpl{
		userRepo:             repo,
		redeemCodeRepo:       redeemRepo,
		authCacheInvalidator: invalidator,
	}

	_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
	require.NoError(t, err)
	require.Empty(t, invalidator.userIDs)
	require.Empty(t, redeemRepo.created)
}
backend/internal/service/api_key_auth_cache.go (new file, 46 lines added)
@@ -0,0 +1,46 @@
package service

// APIKeyAuthSnapshot is the auth-cache snapshot of an API Key (only the fields needed for authentication).
type APIKeyAuthSnapshot struct {
	APIKeyID    int64                    `json:"api_key_id"`
	UserID      int64                    `json:"user_id"`
	GroupID     *int64                   `json:"group_id,omitempty"`
	Status      string                   `json:"status"`
	IPWhitelist []string                 `json:"ip_whitelist,omitempty"`
	IPBlacklist []string                 `json:"ip_blacklist,omitempty"`
	User        APIKeyAuthUserSnapshot   `json:"user"`
	Group       *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
}

// APIKeyAuthUserSnapshot is the user snapshot.
type APIKeyAuthUserSnapshot struct {
	ID          int64   `json:"id"`
	Status      string  `json:"status"`
	Role        string  `json:"role"`
	Balance     float64 `json:"balance"`
	Concurrency int     `json:"concurrency"`
}

// APIKeyAuthGroupSnapshot is the group snapshot.
type APIKeyAuthGroupSnapshot struct {
	ID               int64    `json:"id"`
	Name             string   `json:"name"`
	Platform         string   `json:"platform"`
	Status           string   `json:"status"`
	SubscriptionType string   `json:"subscription_type"`
	RateMultiplier   float64  `json:"rate_multiplier"`
	DailyLimitUSD    *float64 `json:"daily_limit_usd,omitempty"`
	WeeklyLimitUSD   *float64 `json:"weekly_limit_usd,omitempty"`
	MonthlyLimitUSD  *float64 `json:"monthly_limit_usd,omitempty"`
	ImagePrice1K     *float64 `json:"image_price_1k,omitempty"`
	ImagePrice2K     *float64 `json:"image_price_2k,omitempty"`
	ImagePrice4K     *float64 `json:"image_price_4k,omitempty"`
	ClaudeCodeOnly   bool     `json:"claude_code_only"`
	FallbackGroupID  *int64   `json:"fallback_group_id,omitempty"`
}

// APIKeyAuthCacheEntry is a cache entry; it supports negative caching.
type APIKeyAuthCacheEntry struct {
	NotFound bool                `json:"not_found"`
	Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
}
backend/internal/service/api_key_auth_cache_impl.go (new file, 269 lines added)
@@ -0,0 +1,269 @@
package service

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/dgraph-io/ristretto"
)

type apiKeyAuthCacheConfig struct {
	l1Size        int
	l1TTL         time.Duration
	l2TTL         time.Duration
	negativeTTL   time.Duration
	jitterPercent int
	singleflight  bool
}

var (
	jitterRandMu sync.Mutex
	// The auth-cache jitter uses a dedicated random source to avoid seeding the global one.
	jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
)

func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
	if cfg == nil {
		return apiKeyAuthCacheConfig{}
	}
	auth := cfg.APIKeyAuth
	return apiKeyAuthCacheConfig{
		l1Size:        auth.L1Size,
		l1TTL:         time.Duration(auth.L1TTLSeconds) * time.Second,
		l2TTL:         time.Duration(auth.L2TTLSeconds) * time.Second,
		negativeTTL:   time.Duration(auth.NegativeTTLSeconds) * time.Second,
		jitterPercent: auth.JitterPercent,
		singleflight:  auth.Singleflight,
	}
}

func (c apiKeyAuthCacheConfig) l1Enabled() bool {
	return c.l1Size > 0 && c.l1TTL > 0
}

func (c apiKeyAuthCacheConfig) l2Enabled() bool {
	return c.l2TTL > 0
}

func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
	return c.negativeTTL > 0
}

func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
	if ttl <= 0 {
		return ttl
	}
	if c.jitterPercent <= 0 {
		return ttl
	}
	percent := c.jitterPercent
	if percent > 100 {
		percent = 100
	}
	delta := float64(percent) / 100
	jitterRandMu.Lock()
	randVal := jitterRand.Float64()
	jitterRandMu.Unlock()
	factor := 1 - delta + randVal*(2*delta)
	if factor <= 0 {
		return ttl
	}
	return time.Duration(float64(ttl) * factor)
}
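A quick sanity check on the jitter math above: percent is capped at 100, delta = percent/100, and the factor is drawn uniformly from [1-delta, 1+delta), so with jitterPercent=10 a 60s TTL becomes an effective TTL in roughly [54s, 66s). Spreading expirations this way keeps entries written at the same moment from all expiring together and stampeding the database.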
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
	s.authCfg = newAPIKeyAuthCacheConfig(cfg)
	if !s.authCfg.l1Enabled() {
		return
	}
	cache, err := ristretto.NewCache(&ristretto.Config{
		NumCounters: int64(s.authCfg.l1Size) * 10,
		MaxCost:     int64(s.authCfg.l1Size),
		BufferItems: 64,
	})
	if err != nil {
		return
	}
	s.authCacheL1 = cache
}

func (s *APIKeyService) authCacheKey(key string) string {
	sum := sha256.Sum256([]byte(key))
	return hex.EncodeToString(sum[:])
}
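Note that cache keys are the SHA-256 hex digest of the raw key, so plaintext API keys never appear as L1 or Redis cache keys, and every cache key has a fixed, bounded length.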
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
	if s.authCacheL1 != nil {
		if val, ok := s.authCacheL1.Get(cacheKey); ok {
			if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
				return entry, true
			}
		}
	}
	if s.cache == nil || !s.authCfg.l2Enabled() {
		return nil, false
	}
	entry, err := s.cache.GetAuthCache(ctx, cacheKey)
	if err != nil {
		return nil, false
	}
	s.setAuthCacheL1(cacheKey, entry)
	return entry, true
}

func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
	if s.authCacheL1 == nil || entry == nil {
		return
	}
	ttl := s.authCfg.l1TTL
	if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
		ttl = s.authCfg.negativeTTL
	}
	ttl = s.authCfg.jitterTTL(ttl)
	_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
}

func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
	if entry == nil {
		return
	}
	s.setAuthCacheL1(cacheKey, entry)
	if s.cache == nil || !s.authCfg.l2Enabled() {
		return
	}
	_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
}

func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
	if s.authCacheL1 != nil {
		s.authCacheL1.Del(cacheKey)
	}
	if s.cache == nil {
		return
	}
	_ = s.cache.DeleteAuthCache(ctx, cacheKey)
}

func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
	apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
	if err != nil {
		if errors.Is(err, ErrAPIKeyNotFound) {
			entry := &APIKeyAuthCacheEntry{NotFound: true}
			if s.authCfg.negativeEnabled() {
				s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
			}
			return entry, nil
		}
		return nil, fmt.Errorf("get api key: %w", err)
	}
	apiKey.Key = key
	snapshot := s.snapshotFromAPIKey(apiKey)
	if snapshot == nil {
		return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
	}
	entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
	s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
	return entry, nil
}

func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
	if entry == nil {
		return nil, false, nil
	}
	if entry.NotFound {
		return nil, true, ErrAPIKeyNotFound
	}
	if entry.Snapshot == nil {
		return nil, false, nil
	}
	return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}

func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
	if apiKey == nil || apiKey.User == nil {
		return nil
	}
	snapshot := &APIKeyAuthSnapshot{
		APIKeyID:    apiKey.ID,
		UserID:      apiKey.UserID,
		GroupID:     apiKey.GroupID,
		Status:      apiKey.Status,
		IPWhitelist: apiKey.IPWhitelist,
		IPBlacklist: apiKey.IPBlacklist,
		User: APIKeyAuthUserSnapshot{
			ID:          apiKey.User.ID,
			Status:      apiKey.User.Status,
			Role:        apiKey.User.Role,
			Balance:     apiKey.User.Balance,
			Concurrency: apiKey.User.Concurrency,
		},
	}
	if apiKey.Group != nil {
		snapshot.Group = &APIKeyAuthGroupSnapshot{
			ID:               apiKey.Group.ID,
			Name:             apiKey.Group.Name,
			Platform:         apiKey.Group.Platform,
			Status:           apiKey.Group.Status,
			SubscriptionType: apiKey.Group.SubscriptionType,
			RateMultiplier:   apiKey.Group.RateMultiplier,
			DailyLimitUSD:    apiKey.Group.DailyLimitUSD,
			WeeklyLimitUSD:   apiKey.Group.WeeklyLimitUSD,
			MonthlyLimitUSD:  apiKey.Group.MonthlyLimitUSD,
			ImagePrice1K:     apiKey.Group.ImagePrice1K,
			ImagePrice2K:     apiKey.Group.ImagePrice2K,
			ImagePrice4K:     apiKey.Group.ImagePrice4K,
			ClaudeCodeOnly:   apiKey.Group.ClaudeCodeOnly,
			FallbackGroupID:  apiKey.Group.FallbackGroupID,
		}
	}
	return snapshot
}

func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
	if snapshot == nil {
		return nil
	}
	apiKey := &APIKey{
		ID:          snapshot.APIKeyID,
		UserID:      snapshot.UserID,
		GroupID:     snapshot.GroupID,
		Key:         key,
		Status:      snapshot.Status,
		IPWhitelist: snapshot.IPWhitelist,
		IPBlacklist: snapshot.IPBlacklist,
		User: &User{
			ID:          snapshot.User.ID,
			Status:      snapshot.User.Status,
			Role:        snapshot.User.Role,
			Balance:     snapshot.User.Balance,
			Concurrency: snapshot.User.Concurrency,
		},
	}
	if snapshot.Group != nil {
		apiKey.Group = &Group{
			ID:               snapshot.Group.ID,
			Name:             snapshot.Group.Name,
			Platform:         snapshot.Group.Platform,
			Status:           snapshot.Group.Status,
			Hydrated:         true,
			SubscriptionType: snapshot.Group.SubscriptionType,
			RateMultiplier:   snapshot.Group.RateMultiplier,
			DailyLimitUSD:    snapshot.Group.DailyLimitUSD,
			WeeklyLimitUSD:   snapshot.Group.WeeklyLimitUSD,
			MonthlyLimitUSD:  snapshot.Group.MonthlyLimitUSD,
			ImagePrice1K:     snapshot.Group.ImagePrice1K,
			ImagePrice2K:     snapshot.Group.ImagePrice2K,
			ImagePrice4K:     snapshot.Group.ImagePrice4K,
			ClaudeCodeOnly:   snapshot.Group.ClaudeCodeOnly,
			FallbackGroupID:  snapshot.Group.FallbackGroupID,
		}
	}
	return apiKey
}
backend/internal/service/api_key_auth_cache_invalidate.go (new file, 48 lines added)
@@ -0,0 +1,48 @@
package service

import "context"

// InvalidateAuthCacheByKey clears the auth cache for the given API Key.
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
	if key == "" {
		return
	}
	cacheKey := s.authCacheKey(key)
	s.deleteAuthCache(ctx, cacheKey)
}

// InvalidateAuthCacheByUserID clears the auth cache for all API Keys of a user.
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
	if userID <= 0 {
		return
	}
	keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
	if err != nil {
		return
	}
	s.deleteAuthCacheByKeys(ctx, keys)
}

// InvalidateAuthCacheByGroupID clears the auth cache for all API Keys in a group.
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
	if groupID <= 0 {
		return
	}
	keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
	if err != nil {
		return
	}
	s.deleteAuthCacheByKeys(ctx, keys)
}

func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
	if len(keys) == 0 {
		return
	}
	for _, key := range keys {
		if key == "" {
			continue
		}
		s.deleteAuthCache(ctx, s.authCacheKey(key))
	}
}
@@ -12,6 +12,8 @@ import (
 	"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
 	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
+	"github.com/dgraph-io/ristretto"
+	"golang.org/x/sync/singleflight"
 )
 
 var (
@@ -31,9 +33,11 @@ const (
 type APIKeyRepository interface {
 	Create(ctx context.Context, key *APIKey) error
 	GetByID(ctx context.Context, id int64) (*APIKey, error)
-	// GetOwnerID fetches only the API Key's owner ID, for a lightweight permission check before deletion
-	GetOwnerID(ctx context.Context, id int64) (int64, error)
+	// GetKeyAndOwnerID fetches only the API Key's key and owner ID, for lightweight paths such as deletion
+	GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
 	GetByKey(ctx context.Context, key string) (*APIKey, error)
+	// GetByKeyForAuth is an auth-only query that returns a minimal field set
+	GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
 	Update(ctx context.Context, key *APIKey) error
 	Delete(ctx context.Context, id int64) error
 
@@ -45,6 +49,8 @@ type APIKeyRepository interface {
 	SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
 	ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
 	CountByGroupID(ctx context.Context, groupID int64) (int64, error)
+	ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
+	ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
 }
 
 // APIKeyCache defines cache operations for API key service
@@ -55,6 +61,17 @@ type APIKeyCache interface {
 
 	IncrementDailyUsage(ctx context.Context, apiKey string) error
 	SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
+
+	GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
+	SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
+	DeleteAuthCache(ctx context.Context, key string) error
 }
+
+// APIKeyAuthCacheInvalidator provides auth-cache invalidation.
+type APIKeyAuthCacheInvalidator interface {
+	InvalidateAuthCacheByKey(ctx context.Context, key string)
+	InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
+	InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
+}
 
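Since *APIKeyService implements all three methods (see api_key_auth_cache_invalidate.go above), a compile-time conformance check would work; a minimal sketch, not part of this diff:

// Hypothetical compile-time assertion that *APIKeyService satisfies
// APIKeyAuthCacheInvalidator; illustrative only, not part of this commit.
var _ APIKeyAuthCacheInvalidator = (*APIKeyService)(nil)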
 // CreateAPIKeyRequest is the request to create an API Key
@@ -83,6 +100,9 @@ type APIKeyService struct {
 	userSubRepo UserSubscriptionRepository
 	cache       APIKeyCache
 	cfg         *config.Config
+	authCacheL1 *ristretto.Cache
+	authCfg     apiKeyAuthCacheConfig
+	authGroup   singleflight.Group
 }
 
 // NewAPIKeyService creates an APIKeyService instance
@@ -94,7 +114,7 @@ func NewAPIKeyService(
 	cache APIKeyCache,
 	cfg *config.Config,
 ) *APIKeyService {
-	return &APIKeyService{
+	svc := &APIKeyService{
 		apiKeyRepo: apiKeyRepo,
 		userRepo:   userRepo,
 		groupRepo:  groupRepo,
@@ -102,6 +122,8 @@ func NewAPIKeyService(
 		cache: cache,
 		cfg:   cfg,
 	}
+	svc.initAuthCache(cfg)
+	return svc
 }
 
 // GenerateKey generates a random API Key
@@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
 		return nil, fmt.Errorf("create api key: %w", err)
 	}
 
+	s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+
 	return apiKey, nil
 }
 
@@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error)
 
 // GetByKey fetches an API Key by its key string (used for authentication)
 func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
-	// Try the Redis cache first
-	cacheKey := fmt.Sprintf("apikey:%s", key)
+	cacheKey := s.authCacheKey(key)
 
-	// Redis caching could be added here; for now query the database directly
-	apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
+	if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
+		if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
+			if err != nil {
+				return nil, fmt.Errorf("get api key: %w", err)
+			}
+			return apiKey, nil
+		}
+	}
+
+	if s.authCfg.singleflight {
+		value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
+			return s.loadAuthCacheEntry(ctx, key, cacheKey)
+		})
+		if err != nil {
+			return nil, err
+		}
+		entry, _ := value.(*APIKeyAuthCacheEntry)
+		if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
+			if err != nil {
+				return nil, fmt.Errorf("get api key: %w", err)
+			}
+			return apiKey, nil
+		}
+	} else {
+		entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
+		if err != nil {
+			return nil, err
+		}
+		if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
+			if err != nil {
+				return nil, fmt.Errorf("get api key: %w", err)
+			}
+			return apiKey, nil
+		}
+	}
+
+	apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
 	if err != nil {
 		return nil, fmt.Errorf("get api key: %w", err)
 	}
 
-	// Cache to Redis (optional, with a 5-minute TTL)
-	if s.cache != nil {
-		// The API Key could be serialized and cached here
-		_ = cacheKey // reference the variable to avoid an unused-variable error
-	}
-
 	apiKey.Key = key
 	return apiKey, nil
 }
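Note on the control flow in the new GetByKey above: applyAuthCacheEntry reports used=false when the entry is nil or carries no snapshot, so both the singleflight and non-singleflight paths can still fall through to the direct GetByKeyForAuth query as a safety net, while a cached NotFound entry short-circuits with ErrAPIKeyNotFound before touching the database.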
@@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
 		return nil, fmt.Errorf("update api key: %w", err)
 	}
 
+	s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+
 	return apiKey, nil
 }
 
 // Delete removes an API Key
-// Optimization: use GetOwnerID instead of GetByID for the permission check,
-// avoiding loading the full APIKey object and its associations (User, Group), which speeds up deletion
 func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
-	// Fetch only the owner ID for the permission check rather than loading the full object
-	ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
+	key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
 	if err != nil {
 		return fmt.Errorf("get api key: %w", err)
 	}
@@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro
 		return ErrInsufficientPerms
 	}
 
-	// Clear the Redis cache (using ownerID rather than apiKey.UserID)
+	// Clear the Redis cache (using userID rather than apiKey.UserID)
 	if s.cache != nil {
-		_ = s.cache.DeleteCreateAttemptCount(ctx, ownerID)
+		_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
 	}
+	s.InvalidateAuthCacheByKey(ctx, key)
 
 	if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
 		return fmt.Errorf("delete api key: %w", err)
backend/internal/service/api_key_service_cache_test.go (new file, 417 lines added)
@@ -0,0 +1,417 @@
//go:build unit

package service

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/require"
)

type authRepoStub struct {
	getByKeyForAuth   func(ctx context.Context, key string) (*APIKey, error)
	listKeysByUserID  func(ctx context.Context, userID int64) ([]string, error)
	listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
}

func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
	panic("unexpected Create call")
}

func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
	panic("unexpected GetByID call")
}

func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
	panic("unexpected GetKeyAndOwnerID call")
}

func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
	panic("unexpected GetByKey call")
}

func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
	if s.getByKeyForAuth == nil {
		panic("unexpected GetByKeyForAuth call")
	}
	return s.getByKeyForAuth(ctx, key)
}

func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
	panic("unexpected Update call")
}

func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
	panic("unexpected Delete call")
}

func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected ListByUserID call")
}

func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
	panic("unexpected VerifyOwnership call")
}

func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
	panic("unexpected CountByUserID call")
}

func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
	panic("unexpected ExistsByKey call")
}

func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
	panic("unexpected ListByGroupID call")
}

func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
	panic("unexpected SearchAPIKeys call")
}

func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
	panic("unexpected ClearGroupIDByGroupID call")
}

func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
	panic("unexpected CountByGroupID call")
}

func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
	if s.listKeysByUserID == nil {
		panic("unexpected ListKeysByUserID call")
	}
	return s.listKeysByUserID(ctx, userID)
}

func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
	if s.listKeysByGroupID == nil {
		panic("unexpected ListKeysByGroupID call")
	}
	return s.listKeysByGroupID(ctx, groupID)
}

type authCacheStub struct {
	getAuthCache   func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
	setAuthKeys    []string
	deleteAuthKeys []string
}

func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}

func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
	return nil
}

func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
	return nil
}

func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
	return nil
}

func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
	return nil
}

func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
	if s.getAuthCache == nil {
		return nil, redis.Nil
	}
	return s.getAuthCache(ctx, key)
}

func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
	s.setAuthKeys = append(s.setAuthKeys, key)
	return nil
}

func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
	s.deleteAuthKeys = append(s.deleteAuthKeys, key)
	return nil
}

func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return nil, errors.New("unexpected repo call")
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	groupID := int64(9)
	cacheEntry := &APIKeyAuthCacheEntry{
		Snapshot: &APIKeyAuthSnapshot{
			APIKeyID: 1,
			UserID:   2,
			GroupID:  &groupID,
			Status:   StatusActive,
			User: APIKeyAuthUserSnapshot{
				ID:          2,
				Status:      StatusActive,
				Role:        RoleUser,
				Balance:     10,
				Concurrency: 3,
			},
			Group: &APIKeyAuthGroupSnapshot{
				ID:               groupID,
				Name:             "g",
				Platform:         PlatformAnthropic,
				Status:           StatusActive,
				SubscriptionType: SubscriptionTypeStandard,
				RateMultiplier:   1,
			},
		},
	}
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return cacheEntry, nil
	}

	apiKey, err := svc.GetByKey(context.Background(), "k1")
	require.NoError(t, err)
	require.Equal(t, int64(1), apiKey.ID)
	require.Equal(t, int64(2), apiKey.User.ID)
	require.Equal(t, groupID, apiKey.Group.ID)
}

func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return nil, errors.New("unexpected repo call")
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return &APIKeyAuthCacheEntry{NotFound: true}, nil
	}

	_, err := svc.GetByKey(context.Background(), "missing")
	require.ErrorIs(t, err, ErrAPIKeyNotFound)
}

func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return &APIKey{
				ID:     5,
				UserID: 7,
				Status: StatusActive,
				User: &User{
					ID:          7,
					Status:      StatusActive,
					Role:        RoleUser,
					Balance:     12,
					Concurrency: 2,
				},
			}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return nil, redis.Nil
	}

	apiKey, err := svc.GetByKey(context.Background(), "k2")
	require.NoError(t, err)
	require.Equal(t, int64(5), apiKey.ID)
	require.Len(t, cache.setAuthKeys, 1)
}

func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
	var calls int32
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			atomic.AddInt32(&calls, 1)
			return &APIKey{
				ID:     21,
				UserID: 3,
				Status: StatusActive,
				User: &User{
					ID:          3,
					Status:      StatusActive,
					Role:        RoleUser,
					Balance:     5,
					Concurrency: 2,
				},
			}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L1Size:       1000,
			L1TTLSeconds: 60,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	require.NotNil(t, svc.authCacheL1)

	_, err := svc.GetByKey(context.Background(), "k-l1")
	require.NoError(t, err)
	svc.authCacheL1.Wait()
	cacheKey := svc.authCacheKey("k-l1")
	_, ok := svc.authCacheL1.Get(cacheKey)
	require.True(t, ok)
	_, err = svc.GetByKey(context.Background(), "k-l1")
	require.NoError(t, err)
	require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
			return []string{"k1", "k2"}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	svc.InvalidateAuthCacheByUserID(context.Background(), 7)
	require.Len(t, cache.deleteAuthKeys, 2)
}

func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
			return []string{"k1", "k2"}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds: 60,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
	require.Len(t, cache.deleteAuthKeys, 2)
}

func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
			return nil, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds: 60,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	svc.InvalidateAuthCacheByKey(context.Background(), "k1")
	require.Len(t, cache.deleteAuthKeys, 1)
}

func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			return nil, ErrAPIKeyNotFound
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			L2TTLSeconds:       60,
			NegativeTTLSeconds: 30,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
	cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
		return nil, redis.Nil
	}

	_, err := svc.GetByKey(context.Background(), "missing")
	require.ErrorIs(t, err, ErrAPIKeyNotFound)
	require.Len(t, cache.setAuthKeys, 1)
}

func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
	var calls int32
	cache := &authCacheStub{}
	repo := &authRepoStub{
		getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
			atomic.AddInt32(&calls, 1)
			time.Sleep(50 * time.Millisecond)
			return &APIKey{
				ID:     11,
				UserID: 2,
				Status: StatusActive,
				User: &User{
					ID:          2,
					Status:      StatusActive,
					Role:        RoleUser,
					Balance:     1,
					Concurrency: 1,
				},
			}, nil
		},
	}
	cfg := &config.Config{
		APIKeyAuth: config.APIKeyAuthCacheConfig{
			Singleflight: true,
		},
	}
	svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)

	start := make(chan struct{})
	wg := sync.WaitGroup{}
	errs := make([]error, 5)
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			<-start
			_, err := svc.GetByKey(context.Background(), "k1")
			errs[idx] = err
		}(i)
	}
	close(start)
	wg.Wait()

	for _, err := range errs {
		require.NoError(t, err)
	}
	require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}
@@ -20,13 +20,12 @@ import (
 // used to test APIKeyService.Delete in isolation, without a real database.
 //
 // Design notes:
-//   - ownerID: simulates the owner ID returned by GetOwnerID
-//   - ownerErr: simulates the error returned by GetOwnerID (e.g. ErrAPIKeyNotFound)
+//   - apiKey/getByIDErr: simulate the record and error returned by GetKeyAndOwnerID
 //   - deleteErr: simulates the error returned by Delete
 //   - deletedIDs: records the API Key IDs whose deletion was requested, for assertions
 type apiKeyRepoStub struct {
-	ownerID  int64 // return value of GetOwnerID
-	ownerErr error // error return value of GetOwnerID
+	apiKey     *APIKey // return value of GetKeyAndOwnerID
+	getByIDErr error   // error return value of GetKeyAndOwnerID
 	deleteErr  error   // error return value of Delete
 	deletedIDs []int64 // list of deleted API Key IDs
 }
@@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
 }
 
 func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
+	if s.getByIDErr != nil {
+		return nil, s.getByIDErr
+	}
+	if s.apiKey != nil {
+		clone := *s.apiKey
+		return &clone, nil
+	}
 	panic("unexpected GetByID call")
 }
 
-// GetOwnerID returns the preset owner ID or error.
-// It is the first repository method Delete calls, to verify that the caller owns the API Key.
-func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) {
-	return s.ownerID, s.ownerErr
-}
+func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
+	if s.getByIDErr != nil {
+		return "", 0, s.getByIDErr
+	}
+	if s.apiKey != nil {
+		return s.apiKey.Key, s.apiKey.UserID, nil
+	}
+	return "", 0, ErrAPIKeyNotFound
+}
 
 func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
 	panic("unexpected GetByKey call")
 }
 
+func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
+	panic("unexpected GetByKeyForAuth call")
+}
+
 func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
 	panic("unexpected Update call")
 }
@@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
 	panic("unexpected CountByGroupID call")
 }
 
+func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
+	panic("unexpected ListKeysByUserID call")
+}
+
+func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
+	panic("unexpected ListKeysByGroupID call")
+}
+
 // apiKeyCacheStub is a test stub for the APIKeyCache interface.
 // It verifies that cache-cleanup logic is invoked correctly on delete.
 //
 // Design notes:
 //   - invalidated: records the user IDs whose cache was cleared
 type apiKeyCacheStub struct {
-	invalidated []int64 // user IDs passed to DeleteCreateAttemptCount
+	invalidated    []int64  // user IDs passed to DeleteCreateAttemptCount
+	deleteAuthKeys []string // cache keys passed to DeleteAuthCache
 }
 
 // GetCreateAttemptCount returns 0, meaning the user is under the creation rate limit
@@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
 	return nil
 }
 
+func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
+	return nil, nil
+}
+
+func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
+	return nil
+}
+
+func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
+	s.deleteAuthKeys = append(s.deleteAuthKeys, key)
+	return nil
+}
+
 // TestApiKeyService_Delete_OwnerMismatch checks that a non-owner delete attempt returns a permission error.
 // Expected behavior:
-//   - GetOwnerID returns owner ID 1
+//   - GetKeyAndOwnerID returns owner ID 1
 //   - the caller's userID is 2 (mismatch)
 //   - ErrInsufficientPerms is returned
 //   - Delete is not called
 //   - the cache is not cleared
 func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
-	repo := &apiKeyRepoStub{ownerID: 1}
+	repo := &apiKeyRepoStub{
+		apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
+	}
 	cache := &apiKeyCacheStub{}
 	svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
 
@@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
 	require.ErrorIs(t, err, ErrInsufficientPerms)
 	require.Empty(t, repo.deletedIDs)    // verifies Delete was not called
 	require.Empty(t, cache.invalidated)  // verifies the cache was not cleared
+	require.Empty(t, cache.deleteAuthKeys)
 }
 
 // TestApiKeyService_Delete_Success tests an owner successfully deleting an API Key.
 // Expected behavior:
-//   - GetOwnerID returns owner ID 7
+//   - GetKeyAndOwnerID returns owner ID 7
 //   - the caller's userID is 7 (match)
 //   - Delete succeeds
 //   - the cache is cleared correctly (using the owner ID)
 //   - a nil error is returned
 func TestApiKeyService_Delete_Success(t *testing.T) {
-	repo := &apiKeyRepoStub{ownerID: 7}
+	repo := &apiKeyRepoStub{
+		apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
	}
 	cache := &apiKeyCacheStub{}
 	svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
 
@@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
 	require.NoError(t, err)
 	require.Equal(t, []int64{42}, repo.deletedIDs)  // verifies the right API Key was deleted
 	require.Equal(t, []int64{7}, cache.invalidated) // verifies the owner's cache was cleared
+	require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
 }
 
 // TestApiKeyService_Delete_NotFound tests that deleting a nonexistent API Key returns the right error.
 // Expected behavior:
-//   - GetOwnerID returns ErrAPIKeyNotFound
+//   - GetKeyAndOwnerID returns ErrAPIKeyNotFound
 //   - ErrAPIKeyNotFound is returned (wrapped by fmt.Errorf)
 //   - Delete is not called
 //   - the cache is not cleared
 func TestApiKeyService_Delete_NotFound(t *testing.T) {
-	repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
+	repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
 	cache := &apiKeyCacheStub{}
 	svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
 
@@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) {
 	require.ErrorIs(t, err, ErrAPIKeyNotFound)
 	require.Empty(t, repo.deletedIDs)
 	require.Empty(t, cache.invalidated)
+	require.Empty(t, cache.deleteAuthKeys)
 }
 
 // TestApiKeyService_Delete_DeleteFails tests error handling when the delete operation fails.
 // Expected behavior:
-//   - GetOwnerID returns the correct owner ID
+//   - GetKeyAndOwnerID returns the correct owner ID
 //   - the ownership check passes
 //   - the cache is cleared (before the delete)
 //   - Delete is called but returns an error
 //   - the returned error contains "delete api key"
 func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
 	repo := &apiKeyRepoStub{
-		ownerID:   3,
+		apiKey:    &APIKey{ID: 42, UserID: 3, Key: "k"},
 		deleteErr: errors.New("delete failed"),
 	}
 	cache := &apiKeyCacheStub{}
@@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
 	require.ErrorContains(t, err, "delete api key")
 	require.Equal(t, []int64{3}, repo.deletedIDs)   // verifies Delete was called
 	require.Equal(t, []int64{3}, cache.invalidated) // verifies the cache was cleared (even though the delete failed)
+	require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
 }
backend/internal/service/auth_cache_invalidation_test.go (new file, 33 lines added)
@@ -0,0 +1,33 @@
//go:build unit

package service

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestUsageService_InvalidateUsageCaches(t *testing.T) {
	invalidator := &authCacheInvalidatorStub{}
	svc := &UsageService{authCacheInvalidator: invalidator}

	svc.invalidateUsageCaches(context.Background(), 7, false)
	require.Empty(t, invalidator.userIDs)

	svc.invalidateUsageCaches(context.Background(), 7, true)
	require.Equal(t, []int64{7}, invalidator.userIDs)
}

func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
	invalidator := &authCacheInvalidatorStub{}
	svc := &RedeemService{authCacheInvalidator: invalidator}

	svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
	svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
	groupID := int64(3)
	svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})

	require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
}
backend/internal/service/dashboard_aggregation_service.go (new file, 242 lines added)
@@ -0,0 +1,242 @@
package service

import (
	"context"
	"errors"
	"log"
	"sync/atomic"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
)

const (
	defaultDashboardAggregationTimeout         = 2 * time.Minute
	defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
	dashboardAggregationRetentionInterval      = 6 * time.Hour
)

var (
	// ErrDashboardBackfillDisabled is returned when configuration disables backfill.
	ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
	// ErrDashboardBackfillTooLarge is returned when the backfill span exceeds the limit.
	ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
)

// DashboardAggregationRepository defines the repository interface for dashboard pre-aggregation.
type DashboardAggregationRepository interface {
	AggregateRange(ctx context.Context, start, end time.Time) error
	GetAggregationWatermark(ctx context.Context) (time.Time, error)
	UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
	CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
	CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
	EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}

// DashboardAggregationService handles scheduled aggregation and backfill.
type DashboardAggregationService struct {
	repo                 DashboardAggregationRepository
	timingWheel          *TimingWheelService
	cfg                  config.DashboardAggregationConfig
	running              int32
	lastRetentionCleanup atomic.Value // time.Time
}

// NewDashboardAggregationService creates the aggregation service.
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
	var aggCfg config.DashboardAggregationConfig
	if cfg != nil {
		aggCfg = cfg.DashboardAgg
	}
	return &DashboardAggregationService{
		repo:        repo,
		timingWheel: timingWheel,
		cfg:         aggCfg,
	}
}

// Start launches the scheduled aggregation job (configuration changes take effect on restart).
func (s *DashboardAggregationService) Start() {
	if s == nil || s.repo == nil || s.timingWheel == nil {
		return
	}
	if !s.cfg.Enabled {
		log.Printf("[DashboardAggregation] 聚合作业已禁用")
		return
	}

	interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
	if interval <= 0 {
		interval = time.Minute
	}

	if s.cfg.RecomputeDays > 0 {
		go s.recomputeRecentDays()
	}

	s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
		s.runScheduledAggregation()
	})
	log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
	if !s.cfg.BackfillEnabled {
		log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
	}
}

// TriggerBackfill starts a backfill (asynchronously).
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
	if s == nil || s.repo == nil {
		return errors.New("聚合服务未初始化")
	}
	if !s.cfg.BackfillEnabled {
		log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
		return ErrDashboardBackfillDisabled
	}
	if !end.After(start) {
		return errors.New("回填时间范围无效")
	}
	if s.cfg.BackfillMaxDays > 0 {
		maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
		if end.Sub(start) > maxRange {
			return ErrDashboardBackfillTooLarge
		}
	}

	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
		defer cancel()
		if err := s.backfillRange(ctx, start, end); err != nil {
			log.Printf("[DashboardAggregation] 回填失败: %v", err)
		}
	}()
	return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
days := s.cfg.RecomputeDays
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
start := now.AddDate(0, 0, -days)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
|
||||
last = time.Unix(0, 0).UTC()
|
||||
}
|
||||
|
||||
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
start := last.Add(-lookback)
|
||||
if !last.After(epoch) {
|
||||
retentionDays := s.cfg.Retention.UsageLogsDays
|
||||
if retentionDays <= 0 {
|
||||
retentionDays = 1
|
||||
}
|
||||
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
|
||||
} else if start.After(now) {
|
||||
start = now.Add(-lookback)
|
||||
}
|
||||
|
||||
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.repo.UpdateAggregationWatermark(ctx, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", err)
|
||||
}
|
||||
|
||||
s.maybeCleanupRetention(ctx, now)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errors.New("聚合作业正在运行")
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
|
||||
cursor := truncateToDayUTC(startUTC)
|
||||
for cursor.Before(endUTC) {
|
||||
windowEnd := cursor.Add(24 * time.Hour)
|
||||
if windowEnd.After(endUTC) {
|
||||
windowEnd = endUTC
|
||||
}
|
||||
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
cursor = windowEnd
|
||||
}
|
||||
|
||||
if err := s.repo.UpdateAggregationWatermark(ctx, endUTC); err != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", err)
|
||||
}
|
||||
|
||||
s.maybeCleanupRetention(ctx, endUTC)
|
||||
return nil
|
||||
}
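
backfillRange walks the requested span in day-sized windows aligned to UTC midnight, so the first window can start before the caller's start time. A sketch of the windowing with assumed example bounds (needs fmt and time imports, wrapped in a main):

// Illustrative only: a backfill from 2024-05-01T06:00Z to 2024-05-03T12:00Z runs three windows,
// because the cursor starts at truncateToDayUTC(start), i.e. 2024-05-01T00:00Z.
start := time.Date(2024, 5, 1, 6, 0, 0, 0, time.UTC)
end := time.Date(2024, 5, 3, 12, 0, 0, 0, time.UTC)
for cursor := truncateToDayUTC(start); cursor.Before(end); {
	windowEnd := cursor.Add(24 * time.Hour)
	if windowEnd.After(end) {
		windowEnd = end
	}
	fmt.Printf("aggregate [%s, %s)\n", cursor, windowEnd) // full May 1, full May 2, partial May 3
	cursor = windowEnd
}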

func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
	if !end.After(start) {
		return nil
	}
	if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
		log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
	}
	return s.repo.AggregateRange(ctx, start, end)
}

func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
	lastAny := s.lastRetentionCleanup.Load()
	if lastAny != nil {
		if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
			return
		}
	}

	hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
	dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
	usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)

	aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
	if aggErr != nil {
		log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
	}
	usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
	if usageErr != nil {
		log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
	}
	if aggErr == nil && usageErr == nil {
		s.lastRetentionCleanup.Store(now)
	}
}

func truncateToDayUTC(t time.Time) time.Time {
	t = t.UTC()
	return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}
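
A hedged sketch of how an admin handler might drive TriggerBackfill; the handler name and the config key names in the messages are illustrative, not from this commit:

// Illustrative only: TriggerBackfill returns immediately, the work runs in a goroutine.
func handleBackfill(svc *DashboardAggregationService, start, end time.Time) error {
	err := svc.TriggerBackfill(start, end)
	switch {
	case errors.Is(err, ErrDashboardBackfillDisabled):
		return fmt.Errorf("enable backfill in the dashboard aggregation config first: %w", err)
	case errors.Is(err, ErrDashboardBackfillTooLarge):
		return fmt.Errorf("split the range into chunks within the configured max days: %w", err)
	default:
		return err
	}
}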

106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
@@ -0,0 +1,106 @@
package service

import (
	"context"
	"errors"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/stretchr/testify/require"
)

type dashboardAggregationRepoTestStub struct {
	aggregateCalls       int
	lastStart            time.Time
	lastEnd              time.Time
	watermark            time.Time
	aggregateErr         error
	cleanupAggregatesErr error
	cleanupUsageErr      error
}

func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
	s.aggregateCalls++
	s.lastStart = start
	s.lastEnd = end
	return s.aggregateErr
}

func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
	return s.watermark, nil
}

func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
	return nil
}

func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
	return s.cleanupAggregatesErr
}

func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	return s.cleanupUsageErr
}

func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}

func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Enabled:         true,
			IntervalSeconds: 60,
			LookbackSeconds: 120,
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}

	svc.runScheduledAggregation()

	require.Equal(t, 1, repo.aggregateCalls)
	require.False(t, repo.lastEnd.IsZero())
	require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
}

func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 1,
				HourlyDays:    1,
				DailyDays:     1,
			},
		},
	}

	svc.maybeCleanupRetention(context.Background(), time.Now().UTC())

	require.Nil(t, svc.lastRetentionCleanup.Load())
}

func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
	repo := &dashboardAggregationRepoTestStub{}
	svc := &DashboardAggregationService{
		repo: repo,
		cfg: config.DashboardAggregationConfig{
			BackfillEnabled: true,
			BackfillMaxDays: 1,
		},
	}

	start := time.Now().AddDate(0, 0, -3)
	end := time.Now()
	err := svc.TriggerBackfill(start, end)
	require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
	require.Equal(t, 0, repo.aggregateCalls)
}

@@ -2,25 +2,119 @@ package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"sync/atomic"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)

// DashboardService provides aggregated statistics for admin dashboard.
type DashboardService struct {
	usageRepo UsageLogRepository
const (
	defaultDashboardStatsFreshTTL       = 15 * time.Second
	defaultDashboardStatsCacheTTL       = 30 * time.Second
	defaultDashboardStatsRefreshTimeout = 30 * time.Second
)

// ErrDashboardStatsCacheMiss marks a dashboard stats cache miss.
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")

// DashboardStatsCache defines the dashboard stats cache interface.
type DashboardStatsCache interface {
	GetDashboardStats(ctx context.Context) (string, error)
	SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
	DeleteDashboardStats(ctx context.Context) error
}

func NewDashboardService(usageRepo UsageLogRepository) *DashboardService {
type dashboardStatsRangeFetcher interface {
	GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
}

type dashboardStatsCacheEntry struct {
	Stats     *usagestats.DashboardStats `json:"stats"`
	UpdatedAt int64                      `json:"updated_at"`
}

// DashboardService provides admin dashboard statistics.
type DashboardService struct {
	usageRepo      UsageLogRepository
	aggRepo        DashboardAggregationRepository
	cache          DashboardStatsCache
	cacheFreshTTL  time.Duration
	cacheTTL       time.Duration
	refreshTimeout time.Duration
	refreshing     int32
	aggEnabled     bool
	aggInterval    time.Duration
	aggLookback    time.Duration
	aggUsageDays   int
}

func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
	freshTTL := defaultDashboardStatsFreshTTL
	cacheTTL := defaultDashboardStatsCacheTTL
	refreshTimeout := defaultDashboardStatsRefreshTimeout
	aggEnabled := true
	aggInterval := time.Minute
	aggLookback := 2 * time.Minute
	aggUsageDays := 90
	if cfg != nil {
		if !cfg.Dashboard.Enabled {
			cache = nil
		}
		if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
			freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
		}
		if cfg.Dashboard.StatsTTLSeconds > 0 {
			cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
		}
		if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
			refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
		}
		aggEnabled = cfg.DashboardAgg.Enabled
		if cfg.DashboardAgg.IntervalSeconds > 0 {
			aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
		}
		if cfg.DashboardAgg.LookbackSeconds > 0 {
			aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
		}
		if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
			aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
		}
	}
	return &DashboardService{
		usageRepo: usageRepo,
		usageRepo:      usageRepo,
		aggRepo:        aggRepo,
		cache:          cache,
		cacheFreshTTL:  freshTTL,
		cacheTTL:       cacheTTL,
		refreshTimeout: refreshTimeout,
		aggEnabled:     aggEnabled,
		aggInterval:    aggInterval,
		aggLookback:    aggLookback,
		aggUsageDays:   aggUsageDays,
	}
}

func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	stats, err := s.usageRepo.GetDashboardStats(ctx)
	if s.cache != nil {
		cached, fresh, err := s.getCachedDashboardStats(ctx)
		if err == nil && cached != nil {
			s.refreshAggregationStaleness(cached)
			if !fresh {
				s.refreshDashboardStatsAsync()
			}
			return cached, nil
		}
		if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
			log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
		}
	}

	stats, err := s.refreshDashboardStats(ctx)
	if err != nil {
		return nil, fmt.Errorf("get dashboard stats: %w", err)
	}
@@ -43,6 +137,169 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
	return stats, nil
}

func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
	data, err := s.cache.GetDashboardStats(ctx)
	if err != nil {
		return nil, false, err
	}

	var entry dashboardStatsCacheEntry
	if err := json.Unmarshal([]byte(data), &entry); err != nil {
		s.evictDashboardStatsCache(err)
		return nil, false, ErrDashboardStatsCacheMiss
	}
	if entry.Stats == nil {
		s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
		return nil, false, ErrDashboardStatsCacheMiss
	}

	age := time.Since(time.Unix(entry.UpdatedAt, 0))
	return entry.Stats, age <= s.cacheFreshTTL, nil
}

func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	stats, err := s.fetchDashboardStats(ctx)
	if err != nil {
		return nil, err
	}
	s.applyAggregationStatus(ctx, stats)
	cacheCtx, cancel := s.cacheOperationContext()
	defer cancel()
	s.saveDashboardStatsCache(cacheCtx, stats)
	return stats, nil
}

func (s *DashboardService) refreshDashboardStatsAsync() {
	if s.cache == nil {
		return
	}
	if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
		return
	}

	go func() {
		defer atomic.StoreInt32(&s.refreshing, 0)

		ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
		defer cancel()

		stats, err := s.fetchDashboardStats(ctx)
		if err != nil {
			log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
			return
		}
		s.applyAggregationStatus(ctx, stats)
		cacheCtx, cancel := s.cacheOperationContext()
		defer cancel()
		s.saveDashboardStatsCache(cacheCtx, stats)
	}()
}
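
The read path above is a stale-while-revalidate cache: entries younger than cacheFreshTTL are served directly, older but still-cached entries are served immediately while a single goroutine refreshes in the background. A minimal sketch of the same pattern in isolation (generic illustration with assumed types, not the service's API; needs sync, sync/atomic, and time):

// Generic stale-while-revalidate sketch.
type swrCache struct {
	mu         sync.Mutex
	value      string
	updatedAt  time.Time
	freshTTL   time.Duration
	refreshing int32
}

func (c *swrCache) Get(refresh func() string) string {
	c.mu.Lock()
	v, age := c.value, time.Since(c.updatedAt)
	c.mu.Unlock()
	if age > c.freshTTL && atomic.CompareAndSwapInt32(&c.refreshing, 0, 1) {
		go func() { // serve the stale value now, refresh in the background
			defer atomic.StoreInt32(&c.refreshing, 0)
			nv := refresh()
			c.mu.Lock()
			c.value, c.updatedAt = nv, time.Now()
			c.mu.Unlock()
		}()
	}
	return v
}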

func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	if !s.aggEnabled {
		if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
			now := time.Now().UTC()
			start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
			return fetcher.GetDashboardStatsWithRange(ctx, start, now)
		}
	}
	return s.usageRepo.GetDashboardStats(ctx)
}

func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
	if s.cache == nil || stats == nil {
		return
	}

	entry := dashboardStatsCacheEntry{
		Stats:     stats,
		UpdatedAt: time.Now().Unix(),
	}
	data, err := json.Marshal(entry)
	if err != nil {
		log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
		return
	}

	if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
		log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
	}
}

func (s *DashboardService) evictDashboardStatsCache(reason error) {
	if s.cache == nil {
		return
	}
	cacheCtx, cancel := s.cacheOperationContext()
	defer cancel()

	if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
		log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
	}
	if reason != nil {
		log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
	}
}

func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
	return context.WithTimeout(context.Background(), s.refreshTimeout)
}

func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
	if stats == nil {
		return
	}
	updatedAt := s.fetchAggregationUpdatedAt(ctx)
	stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
	stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}

func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
	if stats == nil {
		return
	}
	updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
	stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}

func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
	if s.aggRepo == nil {
		return time.Unix(0, 0).UTC()
	}
	updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
	if err != nil {
		log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
		return time.Unix(0, 0).UTC()
	}
	if updatedAt.IsZero() {
		return time.Unix(0, 0).UTC()
	}
	return updatedAt.UTC()
}

func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
	if !s.aggEnabled {
		return true
	}
	epoch := time.Unix(0, 0).UTC()
	if !updatedAt.After(epoch) {
		return true
	}
	threshold := s.aggInterval + s.aggLookback
	return now.Sub(updatedAt) > threshold
}

func parseStatsUpdatedAt(raw string) time.Time {
	if raw == "" {
		return time.Unix(0, 0).UTC()
	}
	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return time.Unix(0, 0).UTC()
	}
	return parsed.UTC()
}
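
isAggregationStale flags the data once the watermark is older than interval plus lookback. A worked example derived from the code above, using the 60s/120s values that also appear in the tests:

// Illustrative only: with interval=60s and lookback=120s the threshold is 180s.
svc := &DashboardService{aggEnabled: true, aggInterval: time.Minute, aggLookback: 2 * time.Minute}
now := time.Now().UTC()
fmt.Println(svc.isAggregationStale(now.Add(-170*time.Second), now)) // false: within threshold
fmt.Println(svc.isAggregationStale(now.Add(-181*time.Second), now)) // true: watermark too old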

func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
	trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
	if err != nil {

387
backend/internal/service/dashboard_service_test.go
Normal file
@@ -0,0 +1,387 @@
package service

import (
	"context"
	"encoding/json"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
	"github.com/stretchr/testify/require"
)

type usageRepoStub struct {
	UsageLogRepository
	stats      *usagestats.DashboardStats
	rangeStats *usagestats.DashboardStats
	err        error
	rangeErr   error
	calls      int32
	rangeCalls int32
	rangeStart time.Time
	rangeEnd   time.Time
	onCall     chan struct{}
}

func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
	atomic.AddInt32(&s.calls, 1)
	if s.onCall != nil {
		select {
		case s.onCall <- struct{}{}:
		default:
		}
	}
	if s.err != nil {
		return nil, s.err
	}
	return s.stats, nil
}

func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
	atomic.AddInt32(&s.rangeCalls, 1)
	s.rangeStart = start
	s.rangeEnd = end
	if s.rangeErr != nil {
		return nil, s.rangeErr
	}
	if s.rangeStats != nil {
		return s.rangeStats, nil
	}
	return s.stats, nil
}

type dashboardCacheStub struct {
	get       func(ctx context.Context) (string, error)
	set       func(ctx context.Context, data string, ttl time.Duration) error
	del       func(ctx context.Context) error
	getCalls  int32
	setCalls  int32
	delCalls  int32
	lastSetMu sync.Mutex
	lastSet   string
}

func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
	atomic.AddInt32(&c.getCalls, 1)
	if c.get != nil {
		return c.get(ctx)
	}
	return "", ErrDashboardStatsCacheMiss
}

func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
	atomic.AddInt32(&c.setCalls, 1)
	c.lastSetMu.Lock()
	c.lastSet = data
	c.lastSetMu.Unlock()
	if c.set != nil {
		return c.set(ctx, data, ttl)
	}
	return nil
}

func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
	atomic.AddInt32(&c.delCalls, 1)
	if c.del != nil {
		return c.del(ctx)
	}
	return nil
}

type dashboardAggregationRepoStub struct {
	watermark time.Time
	err       error
}

func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
	return nil
}

func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
	if s.err != nil {
		return time.Time{}, s.err
	}
	return s.watermark, nil
}

func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
	return nil
}

func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
	return nil
}

func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
	return nil
}

func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
	return nil
}

func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
	t.Helper()
	c.lastSetMu.Lock()
	data := c.lastSet
	c.lastSetMu.Unlock()

	var entry dashboardStatsCacheEntry
	err := json.Unmarshal([]byte(data), &entry)
	require.NoError(t, err)
	return entry
}

func TestDashboardService_CacheHitFresh(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     10,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	entry := dashboardStatsCacheEntry{
		Stats:     stats,
		UpdatedAt: time.Now().Unix(),
	}
	payload, err := json.Marshal(entry)
	require.NoError(t, err)

	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return string(payload), nil
		},
	}
	repo := &usageRepoStub{
		stats: &usagestats.DashboardStats{TotalUsers: 99},
	}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}

func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     7,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "", ErrDashboardStatsCacheMiss
		},
	}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
	entry := cache.readLastEntry(t)
	require.Equal(t, stats, entry.Stats)
	require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
}

func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
	stats := &usagestats.DashboardStats{
		TotalUsers:     3,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "", nil
		},
	}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}

func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
	staleStats := &usagestats.DashboardStats{
		TotalUsers:     11,
		StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
		StatsStale:     true,
	}
	entry := dashboardStatsCacheEntry{
		Stats:     staleStats,
		UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
	}
	payload, err := json.Marshal(entry)
	require.NoError(t, err)

	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return string(payload), nil
		},
	}
	refreshCh := make(chan struct{}, 1)
	repo := &usageRepoStub{
		stats:  &usagestats.DashboardStats{TotalUsers: 22},
		onCall: refreshCh,
	}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, staleStats, got)

	select {
	case <-refreshCh:
	case <-time.After(1 * time.Second):
		t.Fatal("等待异步刷新超时")
	}
	require.Eventually(t, func() bool {
		return atomic.LoadInt32(&cache.setCalls) >= 1
	}, 1*time.Second, 10*time.Millisecond)
}

func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "not-json", nil
		},
	}
	stats := &usagestats.DashboardStats{TotalUsers: 9}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, stats, got)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
}

func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
	cache := &dashboardCacheStub{
		get: func(ctx context.Context) (string, error) {
			return "not-json", nil
		},
	}
	repo := &usageRepoStub{err: errors.New("db down")}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: true},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: true,
		},
	}
	svc := NewDashboardService(repo, aggRepo, cache, cfg)

	_, err := svc.GetDashboardStats(context.Background())
	require.Error(t, err)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
}

func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
	stats := &usagestats.DashboardStats{}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
	cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
	svc := NewDashboardService(repo, aggRepo, nil, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
	require.True(t, got.StatsStale)
}

func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
	aggNow := time.Now().UTC().Truncate(time.Second)
	stats := &usagestats.DashboardStats{}
	repo := &usageRepoStub{stats: stats}
	aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled:         true,
			IntervalSeconds: 60,
			LookbackSeconds: 120,
		},
	}
	svc := NewDashboardService(repo, aggRepo, nil, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
	require.False(t, got.StatsStale)
}

func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
	expected := &usagestats.DashboardStats{TotalUsers: 42}
	repo := &usageRepoStub{
		rangeStats: expected,
		err:        errors.New("should not call aggregated stats"),
	}
	cfg := &config.Config{
		Dashboard: config.DashboardCacheConfig{Enabled: false},
		DashboardAgg: config.DashboardAggregationConfig{
			Enabled: false,
			Retention: config.DashboardAggregationRetentionConfig{
				UsageLogsDays: 7,
			},
		},
	}
	svc := NewDashboardService(repo, nil, nil, cfg)

	got, err := svc.GetDashboardStats(context.Background())
	require.NoError(t, err)
	require.Equal(t, int64(42), got.TotalUsers)
	require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
	require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
	require.False(t, repo.rangeEnd.IsZero())
	require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
}
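
For reference, the payload these tests marshal and readLastEntry unmarshals is a small JSON envelope. A sketch of its shape; the inner field names of usagestats.DashboardStats are assumptions here, only the outer "stats"/"updated_at" tags come from this diff:

// Illustrative only.
entry := dashboardStatsCacheEntry{
	Stats:     &usagestats.DashboardStats{TotalUsers: 7},
	UpdatedAt: time.Now().Unix(),
}
data, _ := json.Marshal(entry)
// data looks like {"stats":{"total_users":7,...},"updated_at":1714561200}
// Freshness is decided on read by comparing updated_at against cacheFreshTTL;
// the cache-side TTL (cacheTTL) only bounds how long a stale entry can still be served.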

@@ -50,13 +50,15 @@ type UpdateGroupRequest struct {

// GroupService manages groups.
type GroupService struct {
	groupRepo GroupRepository
	groupRepo            GroupRepository
	authCacheInvalidator APIKeyAuthCacheInvalidator
}

// NewGroupService creates a group service instance.
func NewGroupService(groupRepo GroupRepository) *GroupService {
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
	return &GroupService{
		groupRepo: groupRepo,
		groupRepo:            groupRepo,
		authCacheInvalidator: authCacheInvalidator,
	}
}

@@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ
	if err := s.groupRepo.Update(ctx, group); err != nil {
		return nil, fmt.Errorf("update group: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
	}

	return group, nil
}
@@ -167,6 +172,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error {
		return fmt.Errorf("get group: %w", err)
	}

	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
	}
	if err := s.groupRepo.Delete(ctx, id); err != nil {
		return fmt.Errorf("delete group: %w", err)
	}
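
APIKeyAuthCacheInvalidator itself is declared elsewhere in this merge. Based purely on the call sites in this diff (called as statements, so return types are assumed to be void), it plausibly looks like this sketch, not the actual declaration:

// Sketch inferred from call sites; the real interface lives elsewhere in the package.
type APIKeyAuthCacheInvalidator interface {
	InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
	InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}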

404
backend/internal/service/openai_codex_transform.go
Normal file
@@ -0,0 +1,404 @@
package service

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"
)

const (
	opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
	codexCacheTTL          = 15 * time.Minute
)

var codexModelMap = map[string]string{
	"gpt-5.1-codex":             "gpt-5.1-codex",
	"gpt-5.1-codex-low":         "gpt-5.1-codex",
	"gpt-5.1-codex-medium":      "gpt-5.1-codex",
	"gpt-5.1-codex-high":        "gpt-5.1-codex",
	"gpt-5.1-codex-max":         "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-low":     "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-medium":  "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-high":    "gpt-5.1-codex-max",
	"gpt-5.1-codex-max-xhigh":   "gpt-5.1-codex-max",
	"gpt-5.2":                   "gpt-5.2",
	"gpt-5.2-none":              "gpt-5.2",
	"gpt-5.2-low":               "gpt-5.2",
	"gpt-5.2-medium":            "gpt-5.2",
	"gpt-5.2-high":              "gpt-5.2",
	"gpt-5.2-xhigh":             "gpt-5.2",
	"gpt-5.2-codex":             "gpt-5.2-codex",
	"gpt-5.2-codex-low":         "gpt-5.2-codex",
	"gpt-5.2-codex-medium":      "gpt-5.2-codex",
	"gpt-5.2-codex-high":        "gpt-5.2-codex",
	"gpt-5.2-codex-xhigh":       "gpt-5.2-codex",
	"gpt-5.1-codex-mini":        "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
	"gpt-5.1-codex-mini-high":   "gpt-5.1-codex-mini",
	"gpt-5.1":                   "gpt-5.1",
	"gpt-5.1-none":              "gpt-5.1",
	"gpt-5.1-low":               "gpt-5.1",
	"gpt-5.1-medium":            "gpt-5.1",
	"gpt-5.1-high":              "gpt-5.1",
	"gpt-5.1-chat-latest":       "gpt-5.1",
	"gpt-5-codex":               "gpt-5.1-codex",
	"codex-mini-latest":         "gpt-5.1-codex-mini",
	"gpt-5-codex-mini":          "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-medium":   "gpt-5.1-codex-mini",
	"gpt-5-codex-mini-high":     "gpt-5.1-codex-mini",
	"gpt-5":                     "gpt-5.1",
	"gpt-5-mini":                "gpt-5.1",
	"gpt-5-nano":                "gpt-5.1",
}

type codexTransformResult struct {
	Modified        bool
	NormalizedModel string
	PromptCacheKey  string
}

type opencodeCacheMetadata struct {
	ETag        string `json:"etag"`
	LastFetch   string `json:"lastFetch,omitempty"`
	LastChecked int64  `json:"lastChecked"`
}

func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
	result := codexTransformResult{}

	model := ""
	if v, ok := reqBody["model"].(string); ok {
		model = v
	}
	normalizedModel := normalizeCodexModel(model)
	if normalizedModel != "" {
		if model != normalizedModel {
			reqBody["model"] = normalizedModel
			result.Modified = true
		}
		result.NormalizedModel = normalizedModel
	}

	if v, ok := reqBody["store"].(bool); !ok || v {
		reqBody["store"] = false
		result.Modified = true
	}
	if v, ok := reqBody["stream"].(bool); !ok || !v {
		reqBody["stream"] = true
		result.Modified = true
	}

	if _, ok := reqBody["max_output_tokens"]; ok {
		delete(reqBody, "max_output_tokens")
		result.Modified = true
	}
	if _, ok := reqBody["max_completion_tokens"]; ok {
		delete(reqBody, "max_completion_tokens")
		result.Modified = true
	}

	if normalizeCodexTools(reqBody) {
		result.Modified = true
	}

	if v, ok := reqBody["prompt_cache_key"].(string); ok {
		result.PromptCacheKey = strings.TrimSpace(v)
	}

	instructions := strings.TrimSpace(getOpenCodeCodexHeader())
	existingInstructions, _ := reqBody["instructions"].(string)
	existingInstructions = strings.TrimSpace(existingInstructions)

	if instructions != "" {
		if existingInstructions != instructions {
			reqBody["instructions"] = instructions
			result.Modified = true
		}
	}

	if input, ok := reqBody["input"].([]any); ok {
		input = filterCodexInput(input)
		reqBody["input"] = input
		result.Modified = true
	}

	return result
}

func normalizeCodexModel(model string) string {
	if model == "" {
		return "gpt-5.1"
	}

	modelID := model
	if strings.Contains(modelID, "/") {
		parts := strings.Split(modelID, "/")
		modelID = parts[len(parts)-1]
	}

	if mapped := getNormalizedCodexModel(modelID); mapped != "" {
		return mapped
	}

	normalized := strings.ToLower(modelID)

	if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
		return "gpt-5.2-codex"
	}
	if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
		return "gpt-5.2"
	}
	if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
		return "gpt-5.1-codex-max"
	}
	if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
		return "gpt-5.1-codex-mini"
	}
	if strings.Contains(normalized, "codex-mini-latest") ||
		strings.Contains(normalized, "gpt-5-codex-mini") ||
		strings.Contains(normalized, "gpt 5 codex mini") {
		return "codex-mini-latest"
	}
	if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
		return "gpt-5.1-codex"
	}
	if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
		return "gpt-5.1"
	}
	if strings.Contains(normalized, "codex") {
		return "gpt-5.1-codex"
	}
	if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
		return "gpt-5.1"
	}

	return "gpt-5.1"
}
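
A few worked inputs and the values normalizeCodexModel returns for them, derived directly from the map lookup and the substring fallbacks above:

normalizeCodexModel("openai/gpt-5.1-codex-high") // "gpt-5.1-codex" (provider prefix stripped, exact map hit)
normalizeCodexModel("GPT-5.2-Codex-High")        // "gpt-5.2-codex" (case-insensitive map hit)
normalizeCodexModel("my-codex-build")            // "gpt-5.1-codex" (generic "codex" fallback)
normalizeCodexModel("")                          // "gpt-5.1"       (default)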

func getNormalizedCodexModel(modelID string) string {
	if modelID == "" {
		return ""
	}
	if mapped, ok := codexModelMap[modelID]; ok {
		return mapped
	}
	lower := strings.ToLower(modelID)
	for key, value := range codexModelMap {
		if strings.ToLower(key) == lower {
			return value
		}
	}
	return ""
}

func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
	cacheDir := codexCachePath("")
	if cacheDir == "" {
		return ""
	}
	cacheFile := filepath.Join(cacheDir, cacheFileName)
	metaFile := filepath.Join(cacheDir, metaFileName)

	var cachedContent string
	if content, ok := readFile(cacheFile); ok {
		cachedContent = content
	}

	var meta opencodeCacheMetadata
	if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
		if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
			return cachedContent
		}
	}

	content, etag, status, err := fetchWithETag(url, meta.ETag)
	if err == nil && status == http.StatusNotModified && cachedContent != "" {
		return cachedContent
	}
	if err == nil && status >= 200 && status < 300 && content != "" {
		_ = writeFile(cacheFile, content)
		meta = opencodeCacheMetadata{
			ETag:        etag,
			LastFetch:   time.Now().UTC().Format(time.RFC3339),
			LastChecked: time.Now().UnixMilli(),
		}
		_ = writeJSON(metaFile, meta)
		return content
	}

	return cachedContent
}

func getOpenCodeCodexHeader() string {
	return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
}

func GetOpenCodeInstructions() string {
	return getOpenCodeCodexHeader()
}

func filterCodexInput(input []any) []any {
	filtered := make([]any, 0, len(input))
	for _, item := range input {
		m, ok := item.(map[string]any)
		if !ok {
			filtered = append(filtered, item)
			continue
		}
		if typ, ok := m["type"].(string); ok && typ == "item_reference" {
			continue
		}
		delete(m, "id")
		filtered = append(filtered, m)
	}
	return filtered
}

func normalizeCodexTools(reqBody map[string]any) bool {
	rawTools, ok := reqBody["tools"]
	if !ok || rawTools == nil {
		return false
	}
	tools, ok := rawTools.([]any)
	if !ok {
		return false
	}

	modified := false
	for idx, tool := range tools {
		toolMap, ok := tool.(map[string]any)
		if !ok {
			continue
		}

		toolType, _ := toolMap["type"].(string)
		if strings.TrimSpace(toolType) != "function" {
			continue
		}

		function, ok := toolMap["function"].(map[string]any)
		if !ok {
			continue
		}

		if _, ok := toolMap["name"]; !ok {
			if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" {
				toolMap["name"] = name
				modified = true
			}
		}
		if _, ok := toolMap["description"]; !ok {
			if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" {
				toolMap["description"] = desc
				modified = true
			}
		}
		if _, ok := toolMap["parameters"]; !ok {
			if params, ok := function["parameters"]; ok {
				toolMap["parameters"] = params
				modified = true
			}
		}
		if _, ok := toolMap["strict"]; !ok {
			if strict, ok := function["strict"]; ok {
				toolMap["strict"] = strict
				modified = true
			}
		}

		tools[idx] = toolMap
	}

	if modified {
		reqBody["tools"] = tools
	}

	return modified
}

func codexCachePath(filename string) string {
	home, err := os.UserHomeDir()
	if err != nil {
		return ""
	}
	cacheDir := filepath.Join(home, ".opencode", "cache")
	if filename == "" {
		return cacheDir
	}
	return filepath.Join(cacheDir, filename)
}

func readFile(path string) (string, bool) {
	if path == "" {
		return "", false
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return "", false
	}
	return string(data), true
}

func writeFile(path, content string) error {
	if path == "" {
		return fmt.Errorf("empty cache path")
	}
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}
	return os.WriteFile(path, []byte(content), 0o644)
}

func loadJSON(path string, target any) bool {
	data, err := os.ReadFile(path)
	if err != nil {
		return false
	}
	if err := json.Unmarshal(data, target); err != nil {
		return false
	}
	return true
}

func writeJSON(path string, value any) error {
	if path == "" {
		return fmt.Errorf("empty json path")
	}
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}
	data, err := json.Marshal(value)
	if err != nil {
		return err
	}
	return os.WriteFile(path, data, 0o644)
}

func fetchWithETag(url, etag string) (string, string, int, error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", "", 0, err
	}
	req.Header.Set("User-Agent", "sub2api-codex")
	if etag != "" {
		req.Header.Set("If-None-Match", etag)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", 0, err
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", "", resp.StatusCode, err
	}
	return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
}
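
To see the transform end to end, a hedged example of the in-place rewrite applyCodexOAuthTransform performs on a decoded request body; field values are illustrative, and the injected instructions come from the cached opencode header when it is non-empty:

// Illustrative only.
body := map[string]any{
	"model":             "openai/gpt-5-codex",
	"max_output_tokens": 4096,
	"prompt_cache_key":  " conv-123 ",
}
res := applyCodexOAuthTransform(body)
// After the call:
//   body["model"]  == "gpt-5.1-codex"  (normalized via codexModelMap)
//   body["store"]  == false            (forced)
//   body["stream"] == true             (forced)
//   "max_output_tokens" is removed
//   res.Modified == true, res.NormalizedModel == "gpt-5.1-codex", res.PromptCacheKey == "conv-123"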
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -528,33 +530,38 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
promptCacheKey := ""
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
// Apply model mapping
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
|
||||
|
||||
// Apply model mapping (skip for Codex CLI for transparent forwarding)
|
||||
mappedModel := reqModel
|
||||
if !isCodexCLI {
|
||||
mappedModel = account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
reqBody["model"] = mappedModel
|
||||
bodyModified = true
|
||||
}
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API:
|
||||
// 1. Add store: false
|
||||
// 2. Normalize input format for Codex API compatibility
|
||||
if account.Type == AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
// Codex 上游不接受 max_output_tokens 参数,需要在转发前移除。
|
||||
delete(reqBody, "max_output_tokens")
|
||||
bodyModified = true
|
||||
|
||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
||||
// AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]}
|
||||
// Codex API expects: {"content": "..."}
|
||||
if normalizeInputForCodexAPI(reqBody) {
|
||||
if account.Type == AccountTypeOAuth && !isCodexCLI {
|
||||
codexResult := applyCodexOAuthTransform(reqBody)
|
||||
if codexResult.Modified {
|
||||
bodyModified = true
|
||||
}
|
||||
if codexResult.NormalizedModel != "" {
|
||||
mappedModel = codexResult.NormalizedModel
|
||||
}
|
||||
if codexResult.PromptCacheKey != "" {
|
||||
promptCacheKey = codexResult.PromptCacheKey
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize body only if modified
|
||||
@@ -573,7 +580,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Build upstream request
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -674,7 +681,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
switch account.Type {
|
||||
@@ -714,12 +721,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
// Set accept header based on stream mode
|
||||
if isStream {
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
} else {
|
||||
req.Header.Set("accept", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// Whitelist passthrough headers
|
||||
@@ -731,6 +732,22 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
}
|
||||
}
|
||||
}
|
||||
if account.Type == AccountTypeOAuth {
|
||||
req.Header.Set("OpenAI-Beta", "responses=experimental")
|
||||
if isCodexCLI {
|
||||
req.Header.Set("originator", "codex_cli_rs")
|
||||
} else {
|
||||
req.Header.Set("originator", "opencode")
|
||||
}
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
if promptCacheKey != "" {
|
||||
req.Header.Set("conversation_id", promptCacheKey)
|
||||
req.Header.Set("session_id", promptCacheKey)
|
||||
} else {
|
||||
req.Header.Del("conversation_id")
|
||||
req.Header.Del("session_id")
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom User-Agent if configured
|
||||
customUA := account.GetOpenAIUserAgent()
|
||||
@@ -1109,6 +1126,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if account.Type == AccountTypeOAuth {
|
||||
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
|
||||
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
|
||||
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse usage
|
||||
var response struct {
|
||||
Usage struct {
|
||||
@@ -1148,6 +1172,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
	return usage, nil
}

func isEventStreamResponse(header http.Header) bool {
	contentType := strings.ToLower(header.Get("Content-Type"))
	return strings.Contains(contentType, "text/event-stream")
}

func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
	bodyText := string(body)
	finalResponse, ok := extractCodexFinalResponse(bodyText)

	usage := &OpenAIUsage{}
	if ok {
		var response struct {
			Usage struct {
				InputTokens       int `json:"input_tokens"`
				OutputTokens      int `json:"output_tokens"`
				InputTokenDetails struct {
					CachedTokens int `json:"cached_tokens"`
				} `json:"input_tokens_details"`
			} `json:"usage"`
		}
		if err := json.Unmarshal(finalResponse, &response); err == nil {
			usage.InputTokens = response.Usage.InputTokens
			usage.OutputTokens = response.Usage.OutputTokens
			usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens
		}
		body = finalResponse
		if originalModel != mappedModel {
			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
		}
	} else {
		usage = s.parseSSEUsageFromBody(bodyText)
		if originalModel != mappedModel {
			bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
		}
		body = []byte(bodyText)
	}

	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)

	contentType := "application/json; charset=utf-8"
	if !ok {
		contentType = resp.Header.Get("Content-Type")
		if contentType == "" {
			contentType = "text/event-stream"
		}
	}
	c.Data(resp.StatusCode, contentType, body)

	return usage, nil
}

func extractCodexFinalResponse(body string) ([]byte, bool) {
	lines := strings.Split(body, "\n")
	for _, line := range lines {
		if !openaiSSEDataRe.MatchString(line) {
			continue
		}
		data := openaiSSEDataRe.ReplaceAllString(line, "")
		if data == "" || data == "[DONE]" {
			continue
		}
		var event struct {
			Type     string          `json:"type"`
			Response json.RawMessage `json:"response"`
		}
		if json.Unmarshal([]byte(data), &event) != nil {
			continue
		}
		if event.Type == "response.done" || event.Type == "response.completed" {
			if len(event.Response) > 0 {
				return event.Response, true
			}
		}
	}
	return nil, false
}
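
For context, a minimal sketch of the SSE shape this parser targets, assuming openaiSSEDataRe matches a leading "data:" prefix (the regexp itself is defined outside this hunk); the payload values below are hypothetical:

// Hypothetical Codex SSE body: the final response.completed event carries the
// aggregated JSON object that extractCodexFinalResponse returns.
const sampleSSE = `event: response.completed
data: {"type":"response.completed","response":{"id":"resp_1","usage":{"input_tokens":12,"output_tokens":34,"input_tokens_details":{"cached_tokens":5}}}}
data: [DONE]`

// extractCodexFinalResponse(sampleSSE) would yield the raw "response" object
// and true; bodies without a response.done/response.completed event return
// (nil, false), and handleOAuthSSEToJSON falls back to parseSSEUsageFromBody.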

func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
	usage := &OpenAIUsage{}
	lines := strings.Split(body, "\n")
	for _, line := range lines {
		if !openaiSSEDataRe.MatchString(line) {
			continue
		}
		data := openaiSSEDataRe.ReplaceAllString(line, "")
		if data == "" || data == "[DONE]" {
			continue
		}
		s.parseSSEUsage(data, usage)
	}
	return usage
}
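
A hedged sketch of this fallback path; the event payloads are hypothetical, and parseSSEUsage itself is defined outside this hunk:

// Hypothetical fallback-path illustration: no response.completed wrapper, so
// usage must be accumulated from individual data lines. svc stands in for an
// *OpenAIGatewayService instance.
body := "data: {\"type\":\"response.output_text.delta\",\"delta\":\"hi\"}\n" +
	"data: {\"usage\":{\"input_tokens\":7,\"output_tokens\":2}}\n" +
	"data: [DONE]\n"
usage := svc.parseSSEUsageFromBody(body)
_ = usage // holds whatever usage fields parseSSEUsage managed to extract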

func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
	lines := strings.Split(body, "\n")
	for i, line := range lines {
		if !openaiSSEDataRe.MatchString(line) {
			continue
		}
		lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
	}
	return strings.Join(lines, "\n")
}

func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
	if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
@@ -1187,101 +1315,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
	return newBody
}

// normalizeInputForCodexAPI converts the AI SDK multi-part content format to the
// simplified format that the ChatGPT internal Codex API expects.
//
// AI SDK sends content as an array of typed objects:
//
//	{"content": [{"type": "input_text", "text": "hello"}]}
//
// ChatGPT Codex API expects content as a simple string:
//
//	{"content": "hello"}
//
// This function modifies reqBody in-place and returns true if any modification was made.
func normalizeInputForCodexAPI(reqBody map[string]any) bool {
	input, ok := reqBody["input"]
	if !ok {
		return false
	}

	// Input that is already a plain string is compatible as-is.
	if _, isString := input.(string); isString {
		return false
	}

	// Otherwise input should be an array of messages.
	inputArray, ok := input.([]any)
	if !ok {
		return false
	}

	modified := false
	for _, item := range inputArray {
		message, ok := item.(map[string]any)
		if !ok {
			continue
		}

		content, ok := message["content"]
		if !ok {
			continue
		}

		// If content is already a string, no conversion is needed.
		if _, isString := content.(string); isString {
			continue
		}

		// If content is an array (AI SDK format), convert it to a string.
		contentArray, ok := content.([]any)
		if !ok {
			continue
		}

		// Extract text from the content array.
		var textParts []string
		for _, part := range contentArray {
			partMap, ok := part.(map[string]any)
			if !ok {
				continue
			}

			// Handle the different content part types.
			partType, _ := partMap["type"].(string)
			switch partType {
			case "input_text", "text":
				// Extract text from input_text or text parts.
				if text, ok := partMap["text"].(string); ok {
					textParts = append(textParts, text)
				}
			case "input_image", "image":
				// Images may need to be preserved in whatever format the
				// ChatGPT Codex API supports; for now image parts are skipped
				// and lost in the conversion.
				// TODO: Consider preserving image data or handling it separately.
				continue
			case "input_file", "file":
				// Like images, file inputs may need special handling.
				continue
			default:
				// For unknown types, extract text if available.
				if text, ok := partMap["text"].(string); ok {
					textParts = append(textParts, text)
				}
			}
		}

		// Replace the content array with the joined string.
		if len(textParts) > 0 {
			message["content"] = strings.Join(textParts, "\n")
			modified = true
		}
	}

	return modified
}
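
A quick illustrative call showing the in-place flattening; the request values are hypothetical:

// Hypothetical request body in AI SDK multi-part form.
reqBody := map[string]any{
	"input": []any{
		map[string]any{
			"role": "user",
			"content": []any{
				map[string]any{"type": "input_text", "text": "hello"},
				map[string]any{"type": "text", "text": "world"},
			},
		},
	},
}
changed := normalizeInputForCodexAPI(reqBody) // true
// The message content is now the plain string "hello\nworld"; image and file
// parts, had any been present, would have been dropped per the TODO above.
_ = changed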

// OpenAIRecordUsageInput is the input for recording usage
type OpenAIRecordUsageInput struct {
	Result *OpenAIForwardResult

@@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
		Credentials: map[string]any{"base_url": "://invalid-url"},
	}

	_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false)
	_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
	if err == nil {
		t.Fatalf("expected error for invalid base_url when allowlist disabled")
	}
@@ -24,10 +24,11 @@ var (

// PromoService is the promo code service
type PromoService struct {
	promoRepo           PromoCodeRepository
	userRepo            UserRepository
	billingCacheService *BillingCacheService
	entClient           *dbent.Client
	promoRepo            PromoCodeRepository
	userRepo             UserRepository
	billingCacheService  *BillingCacheService
	entClient            *dbent.Client
	authCacheInvalidator APIKeyAuthCacheInvalidator
}

// NewPromoService creates a promo code service instance
@@ -36,12 +37,14 @@ func NewPromoService(
	userRepo UserRepository,
	billingCacheService *BillingCacheService,
	entClient *dbent.Client,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
) *PromoService {
	return &PromoService{
		promoRepo:           promoRepo,
		userRepo:            userRepo,
		billingCacheService: billingCacheService,
		entClient:           entClient,
		promoRepo:            promoRepo,
		userRepo:             userRepo,
		billingCacheService:  billingCacheService,
		entClient:            entClient,
		authCacheInvalidator: authCacheInvalidator,
	}
}

@@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
		return fmt.Errorf("commit transaction: %w", err)
	}

	s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount)

	// Invalidate the balance cache
	if s.billingCacheService != nil {
		go func() {
@@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st
	return nil
}

func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) {
	if bonusAmount == 0 || s.authCacheInvalidator == nil {
		return
	}
	s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}
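
The APIKeyAuthCacheInvalidator type is declared outside these hunks; inferred from its call sites throughout this diff, it presumably reduces to a single method, roughly:

// Presumed shape, reconstructed from call sites in this diff; the real
// declaration (and any return values) lives elsewhere in the service package.
type APIKeyAuthCacheInvalidator interface {
	InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
}

The wire.go hunk further below binds *APIKeyService to this interface via ProvideAPIKeyAuthCacheInvalidator.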

// GenerateRandomCode generates a random promo code
func (s *PromoService) GenerateRandomCode() (string, error) {
	bytes := make([]byte, 8)
backend/internal/service/prompts/codex_opencode_bridge.txt (new file, 122 lines)
@@ -0,0 +1,122 @@
# Codex Running in OpenCode

You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.

## CRITICAL: Tool Replacements

<critical_rule priority="0">
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
- NEVER use: apply_patch, applyPatch
- ALWAYS use: edit tool for ALL file modifications
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
</critical_rule>

<critical_rule priority="0">
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
- NEVER use: update_plan, updatePlan, read_plan, readPlan
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
</critical_rule>

## Available OpenCode Tools

**File Operations:**
- `write` - Create new files
  - Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
- `edit` - Modify existing files (REPLACES apply_patch)
  - Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
- `read` - Read file contents

**Search/Discovery:**
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
- `list` - List directories (requires absolute paths)

**Execution:**
- `bash` - Run shell commands
  - No workdir parameter; do not include it in tool calls.
  - Always include a short description for the command.
  - Do not use cd; use absolute paths in commands.
  - Quote paths containing spaces with double quotes.
  - Chain multiple commands with ';' or '&&'; avoid newlines.
  - Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
  - Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
  - For deletions (rm), verify by listing parent dir with `list`.

**Network:**
- `webfetch` - Fetch web content
  - Use fully-formed URLs (http/https; http auto-upgrades to https).
  - Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
  - Read-only; short cache window.

**Task Management:**
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
- `todoread` - Read current plan

## Substitution Rules

Base instruction says: You MUST use instead:
apply_patch → edit
update_plan → todowrite
read_plan → todoread

**Path Usage:** Use per-tool conventions to avoid conflicts:
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.

## Verification Checklist

Before file/plan modifications:
1. Am I using "edit" NOT "apply_patch"?
2. Am I using "todowrite" NOT "update_plan"?
3. Is this tool in the approved list above?
4. Am I following each tool's path requirements?

If ANY answer is NO → STOP and correct before proceeding.

## OpenCode Working Style

**Communication:**
- Send brief preambles (8-12 words) before tool calls, building on prior context
- Provide progress updates during longer tasks

**Execution:**
- Keep working autonomously until the query is fully resolved before yielding
- Don't return to the user with partial solutions

**Code Approach:**
- New projects: Be ambitious and creative
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise

**Testing:**
- If tests exist: Start specific to your changes, then broader validation

## Advanced Tools

**Task Tool (Sub-Agents):**
- Use the Task tool (functions.task) to launch sub-agents
- Check the Task tool description for current agent types and their capabilities
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
- The agent list is dynamically generated - refer to the tool schema for available agents

**Parallelization:**
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
- Reserve sequential calls for ordered or data-dependent steps.

**MCP Tools:**
- Model Context Protocol servers provide additional capabilities
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
- Check your available tools for MCP integrations
- Use when the tool's functionality matches your task needs

## What Remains from Codex

Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.

## Approvals & Safety

- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).
backend/internal/service/prompts/tool_remap_message.txt (new file, 63 lines)
@@ -0,0 +1,63 @@
<user_instructions priority="0">
|
||||
<environment_override priority="0">
|
||||
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
|
||||
</environment_override>
|
||||
|
||||
<tool_replacements priority="0">
|
||||
<critical_rule priority="0">
|
||||
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||
- NEVER use: apply_patch, applyPatch
|
||||
- ALWAYS use: edit tool for ALL file modifications
|
||||
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||
</critical_rule>
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||
- NEVER use: update_plan, updatePlan
|
||||
- ALWAYS use: todowrite for ALL task/plan operations
|
||||
- Use todoread to read current plan
|
||||
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||
</critical_rule>
|
||||
</tool_replacements>
|
||||
|
||||
<available_tools priority="0">
|
||||
File Operations:
|
||||
• write - Create new files
|
||||
• edit - Modify existing files (REPLACES apply_patch)
|
||||
• patch - Apply diff patches
|
||||
• read - Read file contents
|
||||
|
||||
Search/Discovery:
|
||||
• grep - Search file contents
|
||||
• glob - Find files by pattern
|
||||
• list - List directories (use relative paths)
|
||||
|
||||
Execution:
|
||||
• bash - Run shell commands
|
||||
|
||||
Network:
|
||||
• webfetch - Fetch web content
|
||||
|
||||
Task Management:
|
||||
• todowrite - Manage tasks/plans (REPLACES update_plan)
|
||||
• todoread - Read current plan
|
||||
</available_tools>
|
||||
|
||||
<substitution_rules priority="0">
|
||||
Base instruction says: You MUST use instead:
|
||||
apply_patch → edit
|
||||
update_plan → todowrite
|
||||
read_plan → todoread
|
||||
absolute paths → relative paths
|
||||
</substitution_rules>
|
||||
|
||||
<verification_checklist priority="0">
|
||||
Before file/plan modifications:
|
||||
1. Am I using "edit" NOT "apply_patch"?
|
||||
2. Am I using "todowrite" NOT "update_plan"?
|
||||
3. Is this tool in the approved list above?
|
||||
4. Am I using relative paths?
|
||||
|
||||
If ANY answer is NO → STOP and correct before proceeding.
|
||||
</verification_checklist>
|
||||
</user_instructions>
|
||||
@@ -68,12 +68,13 @@ type RedeemCodeResponse struct {

// RedeemService is the redemption code service
type RedeemService struct {
	redeemRepo          RedeemCodeRepository
	userRepo            UserRepository
	subscriptionService *SubscriptionService
	cache               RedeemCache
	billingCacheService *BillingCacheService
	entClient           *dbent.Client
	redeemRepo           RedeemCodeRepository
	userRepo             UserRepository
	subscriptionService  *SubscriptionService
	cache                RedeemCache
	billingCacheService  *BillingCacheService
	entClient            *dbent.Client
	authCacheInvalidator APIKeyAuthCacheInvalidator
}

// NewRedeemService creates a redemption code service instance
@@ -84,14 +85,16 @@ func NewRedeemService(
	cache RedeemCache,
	billingCacheService *BillingCacheService,
	entClient *dbent.Client,
	authCacheInvalidator APIKeyAuthCacheInvalidator,
) *RedeemService {
	return &RedeemService{
		redeemRepo:          redeemRepo,
		userRepo:            userRepo,
		subscriptionService: subscriptionService,
		cache:               cache,
		billingCacheService: billingCacheService,
		entClient:           entClient,
		redeemRepo:           redeemRepo,
		userRepo:             userRepo,
		subscriptionService:  subscriptionService,
		cache:                cache,
		billingCacheService:  billingCacheService,
		entClient:            entClient,
		authCacheInvalidator: authCacheInvalidator,
	}
}

@@ -324,18 +327,33 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (

// invalidateRedeemCaches invalidates redemption-related caches
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
	if s.billingCacheService == nil {
		return
	}

	switch redeemCode.Type {
	case RedeemTypeBalance:
		if s.authCacheInvalidator != nil {
			s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
		}
		if s.billingCacheService == nil {
			return
		}
		go func() {
			cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
			defer cancel()
			_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
		}()
	case RedeemTypeConcurrency:
		if s.authCacheInvalidator != nil {
			s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
		}
		if s.billingCacheService == nil {
			return
		}
	case RedeemTypeSubscription:
		if s.authCacheInvalidator != nil {
			s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
		}
		if s.billingCacheService == nil {
			return
		}
		if redeemCode.GroupID != nil {
			groupID := *redeemCode.GroupID
			go func() {
@@ -54,17 +54,19 @@ type UsageStats struct {

// UsageService is the usage statistics service
type UsageService struct {
	usageRepo UsageLogRepository
	userRepo  UserRepository
	entClient *dbent.Client
	usageRepo            UsageLogRepository
	userRepo             UserRepository
	entClient            *dbent.Client
	authCacheInvalidator APIKeyAuthCacheInvalidator
}

// NewUsageService creates a usage statistics service instance
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService {
	return &UsageService{
		usageRepo: usageRepo,
		userRepo:  userRepo,
		entClient: entClient,
		usageRepo:            usageRepo,
		userRepo:             userRepo,
		entClient:            entClient,
		authCacheInvalidator: authCacheInvalidator,
	}
}

@@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
	}

	// Deduct the user's balance
	balanceUpdated := false
	if inserted && req.ActualCost > 0 {
		if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
			return nil, fmt.Errorf("update user balance: %w", err)
		}
		balanceUpdated = true
	}

	if tx != nil {
@@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
		}
	}

	s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated)

	return usageLog, nil
}

func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) {
	if !balanceUpdated || s.authCacheInvalidator == nil {
		return
	}
	s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
}

// GetByID fetches a usage log by its ID
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
	log, err := s.usageRepo.GetByID(ctx, id)
@@ -55,13 +55,15 @@ type ChangePasswordRequest struct {

// UserService is the user service
type UserService struct {
	userRepo UserRepository
	userRepo             UserRepository
	authCacheInvalidator APIKeyAuthCacheInvalidator
}

// NewUserService creates a user service instance
func NewUserService(userRepo UserRepository) *UserService {
func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService {
	return &UserService{
		userRepo: userRepo,
		userRepo:             userRepo,
		authCacheInvalidator: authCacheInvalidator,
	}
}

@@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}
	oldConcurrency := user.Concurrency

	// Update fields
	if req.Email != nil {
@@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
	if err := s.userRepo.Update(ctx, user); err != nil {
		return nil, fmt.Errorf("update user: %w", err)
	}
	if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	return user, nil
}
@@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
	if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
		return fmt.Errorf("update balance: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	return nil
}

@@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu
	if err := s.userRepo.UpdateConcurrency(ctx, userID, concurrency); err != nil {
		return fmt.Errorf("update concurrency: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	return nil
}

@@ -192,12 +204,18 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str
	if err := s.userRepo.Update(ctx, user); err != nil {
		return fmt.Errorf("update user: %w", err)
	}
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}

	return nil
}

// Delete deletes a user (admin feature)
func (s *UserService) Delete(ctx context.Context, userID int64) error {
	if s.authCacheInvalidator != nil {
		s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID)
	}
	if err := s.userRepo.Delete(ctx, userID); err != nil {
		return fmt.Errorf("delete user: %w", err)
	}
@@ -49,6 +49,13 @@ func ProvideTokenRefreshService(
	return svc
}

// ProvideDashboardAggregationService creates and starts the dashboard aggregation service
func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
	svc := NewDashboardAggregationService(repo, timingWheel, cfg)
	svc.Start()
	return svc
}

// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
	svc := NewAccountExpiryService(accountRepo, time.Minute)
@@ -145,12 +152,18 @@ func ProvideOpsScheduledReportService(
	return svc
}

// ProvideAPIKeyAuthCacheInvalidator exposes the API key auth cache invalidation capability
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
	return apiKeyService
}

// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
	// Core services
	NewAuthService,
	NewUserService,
	NewAPIKeyService,
	ProvideAPIKeyAuthCacheInvalidator,
	NewGroupService,
	NewAccountService,
	NewProxyService,
@@ -194,6 +207,7 @@ var ProviderSet = wire.NewSet(
	ProvideTokenRefreshService,
	ProvideAccountExpiryService,
	ProvideTimingWheelService,
	ProvideDashboardAggregationService,
	ProvideDeferredService,
	NewAntigravityQuotaFetcher,
	NewUserAttributeService,