Files
xinghuoapi/backend/internal/service/user_group_rate_resolver.go
yangjianbo a18bbb5f2f fix(openai): 统一专属倍率计费链路并补齐回归测试
抽取共享的用户分组专属倍率解析器,统一缓存、singleflight 与回退逻辑。\n\n让 OpenAI 独立计费链路复用专属倍率解析,修复 usage 记录与实际扣费未命中用户专属倍率的问题。\n\n补齐 OpenAI 计费与解析器单元测试,并修复全量回归中暴露的 lint 阻塞项。\n\nCo-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 16:47:51 +08:00

104 lines
2.5 KiB
Go

package service
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
gocache "github.com/patrickmn/go-cache"
"golang.org/x/sync/singleflight"
)
// userGroupRateResolver resolves the rate multiplier for a (user, group) pair.
// It fronts a UserGroupRateRepository with an in-memory cache and collapses
// concurrent lookups for the same pair through a singleflight group.
type userGroupRateResolver struct {
// repo is the store queried when the cache has no entry for the pair.
repo UserGroupRateRepository
// cache holds resolved multipliers keyed by "<userID>:<groupID>".
cache *gocache.Cache
// cacheTTL is the expiration passed to cache.Set when a multiplier is stored.
cacheTTL time.Duration
// sf deduplicates in-flight repository lookups that share a cache key.
sf *singleflight.Group
// logComponent is the component name handed to the logger when a lookup fails.
logComponent string
}
// newUserGroupRateResolver assembles a resolver from the supplied parts and
// substitutes a default for every argument left at its zero value:
//
//   - cacheTTL <= 0     -> defaultUserGroupRateCacheTTL
//   - cache == nil      -> a new go-cache using the effective TTL as default
//     expiration and a one-minute cleanup interval
//   - logComponent == "" -> "service.gateway"
//   - sf == nil         -> a fresh singleflight.Group
//
// repo is stored exactly as given; it is not validated here.
func newUserGroupRateResolver(repo UserGroupRateRepository, cache *gocache.Cache, cacheTTL time.Duration, sf *singleflight.Group, logComponent string) *userGroupRateResolver {
	resolver := &userGroupRateResolver{
		repo:         repo,
		cache:        cache,
		cacheTTL:     cacheTTL,
		sf:           sf,
		logComponent: logComponent,
	}

	// The TTL default must be settled first: a freshly created cache uses it.
	if resolver.cacheTTL <= 0 {
		resolver.cacheTTL = defaultUserGroupRateCacheTTL
	}
	if resolver.cache == nil {
		resolver.cache = gocache.New(resolver.cacheTTL, time.Minute)
	}
	if resolver.sf == nil {
		resolver.sf = &singleflight.Group{}
	}
	if resolver.logComponent == "" {
		resolver.logComponent = "service.gateway"
	}

	return resolver
}
// Resolve returns the rate multiplier that applies to userID within groupID.
//
// Lookup order:
//  1. the in-memory cache, keyed by "<userID>:<groupID>";
//  2. the repository, with concurrent loads for the same key collapsed via
//     singleflight (or executed directly when no singleflight group is set);
//  3. groupDefaultMultiplier as the fallback.
//
// groupDefaultMultiplier is returned when the receiver is nil, when either ID
// is non-positive, when no repository is configured, when the repository
// returns an error (the error is logged under r.logComponent), or when the
// loaded value is not a float64. When the repository reports no user-specific
// rate (nil result), groupDefaultMultiplier is used and is also stored in the
// cache under the same key with r.cacheTTL.
//
// NOTE(review): the repository call runs with the ctx of whichever caller
// actually executes the load; callers sharing that singleflight call receive
// the same outcome, including any context error — confirm this is acceptable.
func (r *userGroupRateResolver) Resolve(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 {
	if r == nil || userID <= 0 || groupID <= 0 {
		return groupDefaultMultiplier
	}

	key := fmt.Sprintf("%d:%d", userID, groupID)

	// fromCache reports a cached float64 for key and counts the hit. Entries of
	// any other type are ignored so that they get reloaded.
	fromCache := func() (float64, bool) {
		if r.cache == nil {
			return 0, false
		}
		cached, found := r.cache.Get(key)
		if !found {
			return 0, false
		}
		multiplier, isFloat := cached.(float64)
		if !isFloat {
			return 0, false
		}
		userGroupRateCacheHitTotal.Add(1)
		return multiplier, true
	}

	if multiplier, ok := fromCache(); ok {
		return multiplier
	}
	if r.repo == nil {
		return groupDefaultMultiplier
	}
	userGroupRateCacheMissTotal.Add(1)

	load := func() (any, error) {
		// Another caller may have populated the cache while this one was
		// waiting to enter the singleflight call.
		if multiplier, ok := fromCache(); ok {
			return multiplier, nil
		}

		userGroupRateCacheLoadTotal.Add(1)
		userRate, repoErr := r.repo.GetByUserAndGroup(ctx, userID, groupID)
		if repoErr != nil {
			return nil, repoErr
		}

		// A nil rate means there is no user-specific override for this group.
		multiplier := groupDefaultMultiplier
		if userRate != nil {
			multiplier = *userRate
		}
		if r.cache != nil {
			r.cache.Set(key, multiplier, r.cacheTTL)
		}
		return multiplier, nil
	}

	var (
		value  any
		err    error
		shared bool
	)
	if r.sf != nil {
		value, err, shared = r.sf.Do(key, load)
	} else {
		// A resolver built without the constructor may lack a singleflight
		// group; load directly instead of dereferencing a nil *Group.
		value, err = load()
	}

	if shared {
		userGroupRateCacheSFSharedTotal.Add(1)
	}
	if err != nil {
		userGroupRateCacheFallbackTotal.Add(1)
		logger.LegacyPrintf(r.logComponent, "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err)
		return groupDefaultMultiplier
	}

	multiplier, ok := value.(float64)
	if !ok {
		userGroupRateCacheFallbackTotal.Add(1)
		return groupDefaultMultiplier
	}
	return multiplier
}