first commit
This commit is contained in:
559
backend/internal/service/account.go
Normal file
559
backend/internal/service/account.go
Normal file
@@ -0,0 +1,559 @@
|
||||
// Package service provides business logic and domain services for the application.
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Notes *string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
// RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
|
||||
// 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
|
||||
RateMultiplier *float64
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
ExpiresAt *time.Time
|
||||
AutoPauseOnExpired bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
Schedulable bool
|
||||
|
||||
RateLimitedAt *time.Time
|
||||
RateLimitResetAt *time.Time
|
||||
OverloadUntil *time.Time
|
||||
|
||||
TempUnschedulableUntil *time.Time
|
||||
TempUnschedulableReason string
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
|
||||
Proxy *Proxy
|
||||
AccountGroups []AccountGroup
|
||||
GroupIDs []int64
|
||||
Groups []*Group
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
Keywords []string `json:"keywords"`
|
||||
DurationMinutes int `json:"duration_minutes"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
|
||||
// BillingRateMultiplier 返回账号计费倍率。
|
||||
// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
|
||||
// - 允许 0,表示该账号计费为 0
|
||||
// - 负数属于非法数据,出于安全考虑按 1.0 处理
|
||||
func (a *Account) BillingRateMultiplier() float64 {
|
||||
if a == nil || a.RateMultiplier == nil {
|
||||
return 1.0
|
||||
}
|
||||
if *a.RateMultiplier < 0 {
|
||||
return 1.0
|
||||
}
|
||||
return *a.RateMultiplier
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if a.AutoPauseOnExpired && a.ExpiresAt != nil && !now.Before(*a.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
|
||||
return false
|
||||
}
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) IsRateLimited() bool {
|
||||
if a.RateLimitResetAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.RateLimitResetAt)
|
||||
}
|
||||
|
||||
func (a *Account) IsOverloaded() bool {
|
||||
if a.OverloadUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.OverloadUntil)
|
||||
}
|
||||
|
||||
func (a *Account) IsOAuth() bool {
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
|
||||
}
|
||||
|
||||
func (a *Account) IsGemini() bool {
|
||||
return a.Platform == PlatformGemini
|
||||
}
|
||||
|
||||
func (a *Account) GeminiOAuthType() string {
|
||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||
return ""
|
||||
}
|
||||
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
|
||||
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
|
||||
return "code_assist"
|
||||
}
|
||||
return oauthType
|
||||
}
|
||||
|
||||
func (a *Account) GeminiTierID() string {
|
||||
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
|
||||
return tierID
|
||||
}
|
||||
|
||||
func (a *Account) IsGeminiCodeAssist() bool {
|
||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
oauthType := a.GeminiOAuthType()
|
||||
if oauthType == "" {
|
||||
return strings.TrimSpace(a.GetCredential("project_id")) != ""
|
||||
}
|
||||
return oauthType == "code_assist"
|
||||
}
|
||||
|
||||
func (a *Account) CanGetUsage() bool {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (a *Account) GetCredential(key string) string {
|
||||
if a.Credentials == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := a.Credentials[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 支持多种类型(兼容历史数据中 expires_at 等字段可能是数字或字符串)
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val
|
||||
case json.Number:
|
||||
// GORM datatypes.JSONMap 使用 UseNumber() 解析,数字类型为 json.Number
|
||||
return val.String()
|
||||
case float64:
|
||||
// JSON 解析后数字默认为 float64
|
||||
return strconv.FormatInt(int64(val), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(val, 10)
|
||||
case int:
|
||||
return strconv.Itoa(val)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式
|
||||
// 兼容以下格式:
|
||||
// - RFC3339 字符串: "2025-01-01T00:00:00Z"
|
||||
// - Unix 时间戳字符串: "1735689600"
|
||||
// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number)
|
||||
func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
s := a.GetCredential(key)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
// 尝试 RFC3339 格式
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return &t
|
||||
}
|
||||
// 尝试 Unix 时间戳(纯数字字符串)
|
||||
if ts, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
t := time.Unix(ts, 0)
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_enabled"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := raw.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_rules"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules := make([]TempUnschedulableRule, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
entry, ok := item.(map[string]any)
|
||||
if !ok || entry == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rule := TempUnschedulableRule{
|
||||
ErrorCode: parseTempUnschedInt(entry["error_code"]),
|
||||
Keywords: parseTempUnschedStrings(entry["keywords"]),
|
||||
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
|
||||
Description: parseTempUnschedString(entry["description"]),
|
||||
}
|
||||
|
||||
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func parseTempUnschedString(value any) string {
|
||||
s, ok := value.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func parseTempUnschedStrings(value any) []string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw []string
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
raw = v
|
||||
case []any:
|
||||
raw = make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
raw = append(raw, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
s := strings.TrimSpace(item)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeAccountNotes(value *string) *string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(*value)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
return &trimmed
|
||||
}
|
||||
|
||||
func parseTempUnschedInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return true
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeAPIKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
return "https://api.anthropic.com"
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
func (a *Account) GetExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["custom_error_codes"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
if arr, ok := raw.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
if f, ok := v.(float64); ok {
|
||||
result = append(result, int(f))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
|
||||
if !a.IsCustomErrorCodesEnabled() {
|
||||
return true
|
||||
}
|
||||
codes := a.GetCustomErrorCodes()
|
||||
if len(codes) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, code := range codes {
|
||||
if code == statusCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
func (a *Account) IsAnthropic() bool {
|
||||
return a.Platform == PlatformAnthropic
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeAPIKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
}
|
||||
}
|
||||
return "https://api.openai.com"
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIAccessToken() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("access_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIRefreshToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("refresh_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIIDToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIUserAgent() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("user_agent")
|
||||
}
|
||||
|
||||
func (a *Account) GetChatGPTAccountID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_user_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIOrganizationID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("organization_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
return a.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
|
||||
// IsMixedSchedulingEnabled 检查 antigravity 账户是否启用混合调度
|
||||
// 启用后可参与 anthropic/gemini 分组的账户调度
|
||||
func (a *Account) IsMixedSchedulingEnabled() bool {
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["mixed_scheduling"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
|
||||
var a Account
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
|
||||
require.Nil(t, a.RateMultiplier)
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
|
||||
v := 0.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 0.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
|
||||
v := -1.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
71
backend/internal/service/account_expiry_service.go
Normal file
71
backend/internal/service/account_expiry_service.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
|
||||
type AccountExpiryService struct {
|
||||
accountRepo AccountRepository
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
|
||||
return &AccountExpiryService{
|
||||
accountRepo: accountRepo,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Start() {
|
||||
if s == nil || s.accountRepo == nil || s.interval <= 0 {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) runOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
|
||||
if err != nil {
|
||||
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
if updated > 0 {
|
||||
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
|
||||
}
|
||||
}
|
||||
13
backend/internal/service/account_group.go
Normal file
13
backend/internal/service/account_group.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Priority int
|
||||
CreatedAt time.Time
|
||||
|
||||
Account *Account
|
||||
Group *Group
|
||||
}
|
||||
351
backend/internal/service/account_service.go
Normal file
351
backend/internal/service/account_service.go
Normal file
@@ -0,0 +1,351 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
|
||||
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *Account) error
|
||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||
// GetByIDs fetches accounts by IDs in a single query.
|
||||
// It should return all accounts found (missing IDs are ignored).
|
||||
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
|
||||
ExistsByID(ctx context.Context, id int64) (bool, error)
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
Update(ctx context.Context, account *Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建账号
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
|
||||
// 验证分组是否存在(如果指定了分组)
|
||||
if len(req.GroupIDs) > 0 {
|
||||
for _, groupID := range req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建账号
|
||||
account := &Account{
|
||||
Name: req.Name,
|
||||
Notes: normalizeAccountNotes(req.Notes),
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: StatusActive,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("create account: %w", err)
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(req.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取账号
|
||||
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
}
|
||||
return accounts, pagination, nil
|
||||
}
|
||||
|
||||
// ListByPlatform 根据平台获取账号列表
|
||||
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by platform: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// ListByGroup 根据分组获取账号列表
|
||||
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by group: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// Update 更新账号
|
||||
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
account.Name = *req.Name
|
||||
}
|
||||
if req.Notes != nil {
|
||||
account.Notes = normalizeAccountNotes(req.Notes)
|
||||
}
|
||||
|
||||
if req.Credentials != nil {
|
||||
account.Credentials = *req.Credentials
|
||||
}
|
||||
|
||||
if req.Extra != nil {
|
||||
account.Extra = *req.Extra
|
||||
}
|
||||
|
||||
if req.ProxyID != nil {
|
||||
account.ProxyID = req.ProxyID
|
||||
}
|
||||
|
||||
if req.Concurrency != nil {
|
||||
account.Concurrency = *req.Concurrency
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
account.Priority = *req.Priority
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
account.Status = *req.Status
|
||||
}
|
||||
if req.ExpiresAt != nil {
|
||||
account.ExpiresAt = req.ExpiresAt
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
for _, groupID := range *req.GroupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 执行更新
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if req.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Delete 删除账号
|
||||
// 优化:使用 ExistsByID 替代 GetByID 进行存在性检查,
|
||||
// 避免加载完整账号对象及其关联数据,提升删除操作的性能
|
||||
func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
// 使用轻量级的存在性检查,而非加载完整账号对象
|
||||
exists, err := s.accountRepo.ExistsByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check account: %w", err)
|
||||
}
|
||||
// 明确返回账号不存在错误,便于调用方区分错误类型
|
||||
if !exists {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新账号状态
|
||||
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
account.Status = status
|
||||
account.ErrorMessage = errorMessage
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastUsed 更新最后使用时间
|
||||
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil {
|
||||
return fmt.Errorf("update last used: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredential 获取账号凭证(安全访问)
|
||||
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
return account.GetCredential(key), nil
|
||||
}
|
||||
|
||||
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
|
||||
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
// 根据平台执行不同的测试逻辑
|
||||
switch account.Platform {
|
||||
case PlatformAnthropic:
|
||||
// TODO: 测试Anthropic API凭证
|
||||
return nil
|
||||
case PlatformOpenAI:
|
||||
// TODO: 测试OpenAI API凭证
|
||||
return nil
|
||||
case PlatformGemini:
|
||||
// TODO: 测试Gemini API凭证
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", account.Platform)
|
||||
}
|
||||
}
|
||||
239
backend/internal/service/account_service_delete_test.go
Normal file
239
backend/internal/service/account_service_delete_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
//go:build unit
|
||||
|
||||
// 账号服务删除方法的单元测试
|
||||
// 测试 AccountService.Delete 方法在各种场景下的行为
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// accountRepoStub 是 AccountRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 AccountService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - exists: 模拟 ExistsByID 返回的存在性结果
|
||||
// - existsErr: 模拟 ExistsByID 返回的错误
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的账号 ID,用于断言验证
|
||||
type accountRepoStub struct {
|
||||
exists bool // ExistsByID 的返回值
|
||||
existsErr error // ExistsByID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的账号 ID 列表
|
||||
}
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *accountRepoStub) Create(ctx context.Context, account *Account) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||
panic("unexpected GetByIDs call")
|
||||
}
|
||||
|
||||
// ExistsByID 返回预设的存在性检查结果。
|
||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
|
||||
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
return s.exists, s.existsErr
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
panic("unexpected GetByCRSAccountID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
// Delete 记录被删除的账号 ID 并返回预设的错误。
|
||||
// 通过 deletedIDs 可以验证删除操作是否被正确调用。
|
||||
func (s *accountRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *accountRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
panic("unexpected ListByGroup call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListActive(ctx context.Context) ([]Account, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
panic("unexpected ListByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
panic("unexpected BatchUpdateLastUsed call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
panic("unexpected SetError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
panic("unexpected AutoPauseExpiredAccounts call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
panic("unexpected BindGroups call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
panic("unexpected ListSchedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
panic("unexpected SetAntigravityQuotaScopeLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
panic("unexpected SetTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearRateLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearAntigravityQuotaScopes call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
panic("unexpected UpdateSessionWindow call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
panic("unexpected UpdateExtra call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
panic("unexpected BulkUpdate call")
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 false(账号不存在)
|
||||
// - 返回 ErrAccountNotFound 错误
|
||||
// - Delete 方法不被调用(deletedIDs 为空)
|
||||
func TestAccountService_Delete_NotFound(t *testing.T) {
|
||||
repo := &accountRepoStub{exists: false}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.ErrorIs(t, err, ErrAccountNotFound)
|
||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_CheckError 测试存在性检查失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回数据库错误
|
||||
// - 返回包含 "check account" 的错误信息
|
||||
// - Delete 方法不被调用
|
||||
func TestAccountService_Delete_CheckError(t *testing.T) {
|
||||
repo := &accountRepoStub{existsErr: errors.New("db down")}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "check account") // 验证错误信息包含上下文
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_DeleteError 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 true(账号存在)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete account" 的错误信息
|
||||
// - deletedIDs 记录了尝试删除的 ID
|
||||
func TestAccountService_Delete_DeleteError(t *testing.T) {
|
||||
repo := &accountRepoStub{
|
||||
exists: true,
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "delete account")
|
||||
require.Equal(t, []int64{55}, repo.deletedIDs) // 验证删除操作被调用
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_Success 测试删除操作成功的场景。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 true(账号存在)
|
||||
// - Delete 成功执行
|
||||
// - 返回 nil 错误
|
||||
// - deletedIDs 记录了被删除的 ID
|
||||
func TestAccountService_Delete_Success(t *testing.T) {
|
||||
repo := &accountRepoStub{exists: true}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{55}, repo.deletedIDs) // 验证正确的 ID 被删除
|
||||
}
|
||||
847
backend/internal/service/account_test_service.go
Normal file
847
backend/internal/service/account_test_service.go
Normal file
@@ -0,0 +1,847 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
)
|
||||
|
||||
// TestEvent represents a SSE event for account testing
|
||||
type TestEvent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
cfg *config.Config,
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) {
|
||||
if s.cfg == nil {
|
||||
return "", errors.New("config is not available")
|
||||
}
|
||||
if !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
|
||||
}
|
||||
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
|
||||
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
|
||||
RequireAllowlist: true,
|
||||
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// generateSessionString generates a Claude Code style session string
|
||||
func generateSessionString() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
hex64 := hex.EncodeToString(bytes)
|
||||
sessionUUID := uuid.New().String()
|
||||
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil
|
||||
}
|
||||
|
||||
// createTestPayload creates a Claude Code style test request payload
|
||||
func createTestPayload(modelID string) (map[string]any, error) {
|
||||
sessionID, err := generateSessionString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"model": modelID,
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"system": []map[string]any{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
},
|
||||
"metadata": map[string]string{
|
||||
"user_id": sessionID,
|
||||
},
|
||||
"max_tokens": 1024,
|
||||
"temperature": 1,
|
||||
"stream": true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestAccountConnection tests an account's connection by sending a test request
|
||||
// All account types use full Claude Code client characteristics, only auth header differs
|
||||
// modelID is optional - if empty, defaults to claude.DefaultTestModel
|
||||
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
// Route to platform-specific test method
|
||||
if account.IsOpenAI() {
|
||||
return s.testOpenAIAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.IsGemini() {
|
||||
return s.testGeminiAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.testAntigravityAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
return s.testClaudeAccountConnection(c, account, modelID)
|
||||
}
|
||||
|
||||
// testClaudeAccountConnection tests an Anthropic Claude account's connection
|
||||
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = claude.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
var useBearer bool
|
||||
var apiURL string
|
||||
|
||||
if account.IsOAuth() {
|
||||
// OAuth or Setup Token - use Bearer token
|
||||
useBearer = true
|
||||
apiURL = testClaudeAPIURL
|
||||
authToken = account.GetCredential("access_token")
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use x-api-key header
|
||||
useBearer = false
|
||||
authToken = account.GetCredential("api_key")
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com"
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/v1/messages"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create Claude Code style payload (same for all account types)
|
||||
payload, err := createTestPayload(testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create test payload")
|
||||
}
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("anthropic-beta", claude.DefaultBetaHeader)
|
||||
|
||||
// Apply Claude Code client headers
|
||||
for key, value := range claude.DefaultHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
// Set authentication header
|
||||
if useBearer {
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
} else {
|
||||
req.Header.Set("x-api-key", authToken)
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processClaudeStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = openai.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == "apikey" {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine authentication method and API URL
|
||||
var authToken string
|
||||
var apiURL string
|
||||
var isOAuth bool
|
||||
var chatgptAccountID string
|
||||
|
||||
if account.IsOAuth() {
|
||||
isOAuth = true
|
||||
// OAuth - use Bearer token with ChatGPT internal API
|
||||
authToken = account.GetOpenAIAccessToken()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// OAuth uses ChatGPT internal API
|
||||
apiURL = chatgptCodexAPIURL
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use Platform API
|
||||
authToken = account.GetOpenAIApiKey()
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
apiURL = strings.TrimSuffix(normalizedBaseURL, "/") + "/responses"
|
||||
} else {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create OpenAI Responses API payload
|
||||
payload := createOpenAITestPayload(testModelID, isOAuth)
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create request")
|
||||
}
|
||||
|
||||
// Set common headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
|
||||
// Set OAuth-specific headers for ChatGPT internal API
|
||||
if isOAuth {
|
||||
req.Host = "chatgpt.com"
|
||||
req.Header.Set("accept", "text/event-stream")
|
||||
if chatgptAccountID != "" {
|
||||
req.Header.Set("chatgpt-account-id", chatgptAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testGeminiAccountConnection tests a Gemini account's connection
|
||||
func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = geminicli.DefaultTestModel
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
testModelID = mappedModel
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Create test payload (Gemini format)
|
||||
payload := createGeminiTestPayload()
|
||||
|
||||
// Build request based on account type
|
||||
var req *http.Request
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeAPIKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
default:
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build request: %s", err.Error()))
|
||||
}
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
// Get proxy and execute request
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
// Process SSE stream
|
||||
return s.processGeminiStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testAntigravityAccountConnection tests an Antigravity account's connection
|
||||
// 支持 Claude 和 Gemini 两种协议,使用非流式请求
|
||||
func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 默认模型:Claude 使用 claude-sonnet-4-5,Gemini 使用 gemini-3-pro-preview
|
||||
testModelID := modelID
|
||||
if testModelID == "" {
|
||||
testModelID = "claude-sonnet-4-5"
|
||||
}
|
||||
|
||||
if s.antigravityGatewayService == nil {
|
||||
return s.sendErrorAndEnd(c, "Antigravity gateway service not configured")
|
||||
}
|
||||
|
||||
// Set SSE headers
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
// Send test_start event
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
|
||||
// 调用 AntigravityGatewayService.TestConnection(复用协议转换逻辑)
|
||||
result, err := s.antigravityGatewayService.TestConnection(ctx, account, testModelID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, err.Error())
|
||||
}
|
||||
|
||||
// 发送响应内容
|
||||
if result.Text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: result.Text})
|
||||
}
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildGeminiAPIKeyRequest builds request for Gemini API Key accounts
|
||||
func (s *AccountTestService) buildGeminiAPIKeyRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
return nil, fmt.Errorf("no API key available")
|
||||
}
|
||||
|
||||
baseURL := account.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for real-time feedback
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse",
|
||||
strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// buildGeminiOAuthRequest builds request for Gemini OAuth accounts
|
||||
func (s *AccountTestService) buildGeminiOAuthRequest(ctx context.Context, account *Account, modelID string, payload []byte) (*http.Request, error) {
|
||||
if s.geminiTokenProvider == nil {
|
||||
return nil, fmt.Errorf("gemini token provider not configured")
|
||||
}
|
||||
|
||||
// Get access token (auto-refreshes if needed)
|
||||
accessToken, err := s.geminiTokenProvider.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token: %w", err)
|
||||
}
|
||||
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID == "" {
|
||||
// AI Studio OAuth mode (no project_id): call generativelanguage API directly with Bearer token.
|
||||
baseURL := account.GetCredential("base_url")
|
||||
if strings.TrimSpace(baseURL) == "" {
|
||||
baseURL = geminicli.AIStudioBaseURL
|
||||
}
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1beta/models/%s:streamGenerateContent?alt=sse", strings.TrimRight(normalizedBaseURL, "/"), modelID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Code Assist mode (with project_id)
|
||||
return s.buildCodeAssistRequest(ctx, accessToken, projectID, modelID, payload)
|
||||
}
|
||||
|
||||
// buildCodeAssistRequest builds request for Google Code Assist API (used by Gemini CLI and Antigravity)
|
||||
func (s *AccountTestService) buildCodeAssistRequest(ctx context.Context, accessToken, projectID, modelID string, payload []byte) (*http.Request, error) {
|
||||
var inner map[string]any
|
||||
if err := json.Unmarshal(payload, &inner); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrapped := map[string]any{
|
||||
"model": modelID,
|
||||
"project": projectID,
|
||||
"request": inner,
|
||||
}
|
||||
wrappedBytes, _ := json.Marshal(wrapped)
|
||||
|
||||
normalizedBaseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fullURL := fmt.Sprintf("%s/v1internal:streamGenerateContent?alt=sse", normalizedBaseURL)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewReader(wrappedBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// createGeminiTestPayload creates a minimal test payload for Gemini API
|
||||
func createGeminiTestPayload() []byte {
|
||||
payload := map[string]any{
|
||||
"contents": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"parts": []map[string]any{
|
||||
{"text": "hi"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"systemInstruction": map[string]any{
|
||||
"parts": []map[string]any{
|
||||
{"text": "You are a helpful AI assistant."},
|
||||
},
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(payload)
|
||||
return bytes
|
||||
}
|
||||
|
||||
// processGeminiStream processes SSE stream from Gemini API
|
||||
func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Support two Gemini response formats:
|
||||
// - AI Studio: {"candidates": [...]}
|
||||
// - Gemini CLI: {"response": {"candidates": [...]}}
|
||||
if resp, ok := data["response"].(map[string]any); ok && resp != nil {
|
||||
data = resp
|
||||
}
|
||||
if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 {
|
||||
if candidate, ok := candidates[0].(map[string]any); ok {
|
||||
// Extract content first (before checking completion)
|
||||
if content, ok := candidate["content"].(map[string]any); ok {
|
||||
if parts, ok := content["parts"].([]any); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if text, ok := partMap["text"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for completion after extracting content
|
||||
if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle errors
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
errorMsg := "Unknown error"
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createOpenAITestPayload creates a test payload for OpenAI Responses API
|
||||
func createOpenAITestPayload(modelID string, isOAuth bool) map[string]any {
|
||||
payload := map[string]any{
|
||||
"model": modelID,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": "hi",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
// OAuth accounts using ChatGPT internal API require store: false
|
||||
if isOAuth {
|
||||
payload["store"] = false
|
||||
}
|
||||
|
||||
// All accounts require instructions for Responses API
|
||||
payload["instructions"] = openai.DefaultInstructions
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
// processClaudeStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !sseDataPrefix.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := data["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "content_block_delta":
|
||||
if delta, ok := data["delta"].(map[string]any); ok {
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !sseDataPrefix.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := data["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "response.output_text.delta":
|
||||
// OpenAI Responses API uses "delta" field for text content
|
||||
if delta, ok := data["delta"].(string); ok && delta != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: delta})
|
||||
}
|
||||
case "response.completed":
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
case "error":
|
||||
errorMsg := "Unknown error"
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
if msg, ok := errData["message"].(string); ok {
|
||||
errorMsg = msg
|
||||
}
|
||||
}
|
||||
return s.sendErrorAndEnd(c, errorMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendEvent sends a SSE event to the client
|
||||
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
|
||||
eventJSON, _ := json.Marshal(event)
|
||||
if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil {
|
||||
log.Printf("failed to write SSE event: %v", err)
|
||||
return
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// sendErrorAndEnd sends an error event and ends the stream
|
||||
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
|
||||
log.Printf("Account test error: %s", errorMsg)
|
||||
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
|
||||
return fmt.Errorf("%s", errorMsg)
|
||||
}
|
||||
577
backend/internal/service/account_usage_service.go
Normal file
577
backend/internal/service/account_usage_service.go
Normal file
@@ -0,0 +1,577 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
// Create creates a usage log and returns whether it was actually inserted.
|
||||
// inserted is false when the insert was skipped due to conflict (idempotent retries).
|
||||
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
|
||||
GetByID(ctx context.Context, id int64) (*UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
|
||||
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
|
||||
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
|
||||
|
||||
// Admin usage listing/stats
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error)
|
||||
|
||||
// Account stats
|
||||
GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error)
|
||||
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
type apiUsageCache struct {
|
||||
response *ClaudeUsageResponse
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// windowStatsCache 缓存从本地数据库查询的窗口统计(requests, tokens, cost)
|
||||
type windowStatsCache struct {
|
||||
stats *WindowStats
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// antigravityUsageCache 缓存 Antigravity 额度数据
|
||||
type antigravityUsageCache struct {
|
||||
usageInfo *UsageInfo
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
// UsageCache 封装账户使用量相关的缓存
|
||||
type UsageCache struct {
|
||||
apiCache sync.Map // accountID -> *apiUsageCache
|
||||
windowStatsCache sync.Map // accountID -> *windowStatsCache
|
||||
antigravityCache sync.Map // accountID -> *antigravityUsageCache
|
||||
}
|
||||
|
||||
// NewUsageCache 创建 UsageCache 实例
|
||||
func NewUsageCache() *UsageCache {
|
||||
return &UsageCache{}
|
||||
}
|
||||
|
||||
// WindowStats 窗口期统计
|
||||
//
|
||||
// cost: 账号口径费用(total_cost * account_rate_multiplier)
|
||||
// standard_cost: 标准费用(total_cost,不含倍率)
|
||||
// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
|
||||
type WindowStats struct {
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
Cost float64 `json:"cost"`
|
||||
StandardCost float64 `json:"standard_cost"`
|
||||
UserCost float64 `json:"user_cost"`
|
||||
}
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
type UsageProgress struct {
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
UsedRequests int64 `json:"used_requests,omitempty"`
|
||||
LimitRequests int64 `json:"limit_requests,omitempty"`
|
||||
}
|
||||
|
||||
// AntigravityModelQuota Antigravity 单个模型的配额信息
|
||||
type AntigravityModelQuota struct {
|
||||
Utilization int `json:"utilization"` // 使用率 0-100
|
||||
ResetTime string `json:"reset_time"` // 重置时间 ISO8601
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiSharedDaily *UsageProgress `json:"gemini_shared_daily,omitempty"` // Gemini shared pool RPD (Google One / Code Assist)
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
GeminiSharedMinute *UsageProgress `json:"gemini_shared_minute,omitempty"` // Gemini shared pool RPM (Google One / Code Assist)
|
||||
GeminiProMinute *UsageProgress `json:"gemini_pro_minute,omitempty"` // Gemini Pro RPM
|
||||
GeminiFlashMinute *UsageProgress `json:"gemini_flash_minute,omitempty"` // Gemini Flash RPM
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
type ClaudeUsageResponse struct {
|
||||
FiveHour struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt string `json:"resets_at"`
|
||||
} `json:"five_hour"`
|
||||
SevenDay struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt string `json:"resets_at"`
|
||||
} `json:"seven_day"`
|
||||
SevenDaySonnet struct {
|
||||
Utilization float64 `json:"utilization"`
|
||||
ResetsAt string `json:"resets_at"`
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||
cache *UsageCache
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(
|
||||
accountRepo AccountRepository,
|
||||
usageLogRepo UsageLogRepository,
|
||||
usageFetcher ClaudeUsageFetcher,
|
||||
geminiQuotaService *GeminiQuotaService,
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||
cache *UsageCache,
|
||||
) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageFetcher: usageFetcher,
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUsage 获取账号使用量
|
||||
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
// API Key账号: 不支持usage查询
|
||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformGemini {
|
||||
return s.getGeminiUsage(ctx, account)
|
||||
}
|
||||
|
||||
// Antigravity 平台:使用 AntigravityQuotaFetcher 获取额度
|
||||
if account.Platform == PlatformAntigravity {
|
||||
return s.getAntigravityUsage(ctx, account)
|
||||
}
|
||||
|
||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||
if account.CanGetUsage() {
|
||||
var apiResp *ClaudeUsageResponse
|
||||
|
||||
// 1. 检查 API 缓存(10 分钟)
|
||||
if cached, ok := s.cache.apiCache.Load(accountID); ok {
|
||||
if cache, ok := cached.(*apiUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
apiResp = cache.response
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果没有缓存,从 API 获取
|
||||
if apiResp == nil {
|
||||
apiResp, err = s.fetchOAuthUsageRaw(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 缓存 API 响应
|
||||
s.cache.apiCache.Store(accountID, &apiUsageCache{
|
||||
response: apiResp,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// 3. 构建 UsageInfo(每次都重新计算 RemainingSeconds)
|
||||
now := time.Now()
|
||||
usage := s.buildUsageInfo(apiResp, &now)
|
||||
|
||||
// 4. 添加窗口统计(有独立缓存,1 分钟)
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API)
|
||||
if account.Type == AccountTypeSetupToken {
|
||||
usage := s.estimateSetupTokenUsage(account)
|
||||
// 添加窗口统计
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// API Key账号不支持usage查询
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
|
||||
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
|
||||
if !ok {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
|
||||
dayTotals := geminiAggregateUsage(stats)
|
||||
dailyResetAt := geminiDailyResetTime(now)
|
||||
|
||||
// Daily window (RPD)
|
||||
if quota.SharedRPD > 0 {
|
||||
totalReq := dayTotals.ProRequests + dayTotals.FlashRequests
|
||||
totalTokens := dayTotals.ProTokens + dayTotals.FlashTokens
|
||||
totalCost := dayTotals.ProCost + dayTotals.FlashCost
|
||||
usage.GeminiSharedDaily = buildGeminiUsageProgress(totalReq, quota.SharedRPD, dailyResetAt, totalTokens, totalCost, now)
|
||||
} else {
|
||||
usage.GeminiProDaily = buildGeminiUsageProgress(dayTotals.ProRequests, quota.ProRPD, dailyResetAt, dayTotals.ProTokens, dayTotals.ProCost, now)
|
||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(dayTotals.FlashRequests, quota.FlashRPD, dailyResetAt, dayTotals.FlashTokens, dayTotals.FlashCost, now)
|
||||
}
|
||||
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
minuteTotals := geminiAggregateUsage(minuteStats)
|
||||
|
||||
if quota.SharedRPM > 0 {
|
||||
totalReq := minuteTotals.ProRequests + minuteTotals.FlashRequests
|
||||
totalTokens := minuteTotals.ProTokens + minuteTotals.FlashTokens
|
||||
totalCost := minuteTotals.ProCost + minuteTotals.FlashCost
|
||||
usage.GeminiSharedMinute = buildGeminiUsageProgress(totalReq, quota.SharedRPM, minuteResetAt, totalTokens, totalCost, now)
|
||||
} else {
|
||||
usage.GeminiProMinute = buildGeminiUsageProgress(minuteTotals.ProRequests, quota.ProRPM, minuteResetAt, minuteTotals.ProTokens, minuteTotals.ProCost, now)
|
||||
usage.GeminiFlashMinute = buildGeminiUsageProgress(minuteTotals.FlashRequests, quota.FlashRPM, minuteResetAt, minuteTotals.FlashTokens, minuteTotals.FlashCost, now)
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// getAntigravityUsage 获取 Antigravity 账户额度
|
||||
func (s *AccountUsageService) getAntigravityUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
if s.antigravityQuotaFetcher == nil || !s.antigravityQuotaFetcher.CanFetch(account) {
|
||||
now := time.Now()
|
||||
return &UsageInfo{UpdatedAt: &now}, nil
|
||||
}
|
||||
|
||||
// 1. 检查缓存(10 分钟)
|
||||
if cached, ok := s.cache.antigravityCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*antigravityUsageCache); ok && time.Since(cache.timestamp) < apiCacheTTL {
|
||||
// 重新计算 RemainingSeconds
|
||||
usage := cache.usageInfo
|
||||
if usage.FiveHour != nil && usage.FiveHour.ResetsAt != nil {
|
||||
usage.FiveHour.RemainingSeconds = int(time.Until(*usage.FiveHour.ResetsAt).Seconds())
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 获取代理 URL
|
||||
proxyURL := s.antigravityQuotaFetcher.GetProxyURL(ctx, account)
|
||||
|
||||
// 3. 调用 API 获取额度
|
||||
result, err := s.antigravityQuotaFetcher.FetchQuota(ctx, account, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch antigravity quota failed: %w", err)
|
||||
}
|
||||
|
||||
// 4. 缓存结果
|
||||
s.cache.antigravityCache.Store(account.ID, &antigravityUsageCache{
|
||||
usageInfo: result.UsageInfo,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
|
||||
return result.UsageInfo, nil
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
// 使用独立缓存(1 分钟),与 API 缓存分离
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
// 修复:即使 FiveHour 为 nil,也要尝试获取统计数据
|
||||
// 因为 SevenDay/SevenDaySonnet 可能需要
|
||||
if usage.FiveHour == nil && usage.SevenDay == nil && usage.SevenDaySonnet == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查窗口统计缓存(1 分钟)
|
||||
var windowStats *WindowStats
|
||||
if cached, ok := s.cache.windowStatsCache.Load(account.ID); ok {
|
||||
if cache, ok := cached.(*windowStatsCache); ok && time.Since(cache.timestamp) < windowStatsCacheTTL {
|
||||
windowStats = cache.stats
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有缓存,从数据库查询
|
||||
if windowStats == nil {
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
windowStats = &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}
|
||||
|
||||
// 缓存窗口统计(1 分钟)
|
||||
s.cache.windowStatsCache.Store(account.ID, &windowStatsCache{
|
||||
stats: windowStats,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// 为 FiveHour 添加 WindowStats(5h 窗口统计)
|
||||
if usage.FiveHour != nil {
|
||||
usage.FiveHour.WindowStats = windowStats
|
||||
}
|
||||
}
|
||||
|
||||
// GetTodayStats 获取账号今日统计
|
||||
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get today stats failed: %w", err)
|
||||
}
|
||||
|
||||
return &WindowStats{
|
||||
Requests: stats.Requests,
|
||||
Tokens: stats.Tokens,
|
||||
Cost: stats.Cost,
|
||||
StandardCost: stats.StandardCost,
|
||||
UserCost: stats.UserCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.AccountUsageStatsResponse, error) {
|
||||
stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account usage stats failed: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
func parseTime(s string) (time.Time, error) {
|
||||
formats := []string{
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02T15:04:05.000Z",
|
||||
}
|
||||
for _, format := range formats {
|
||||
if t, err := time.Parse(format, s); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
|
||||
}
|
||||
|
||||
// buildUsageInfo 构建UsageInfo
|
||||
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
|
||||
// 5小时窗口 - 始终创建对象(即使 ResetsAt 为空)
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: resp.FiveHour.Utilization,
|
||||
}
|
||||
if resp.FiveHour.ResetsAt != "" {
|
||||
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
|
||||
info.FiveHour.ResetsAt = &fiveHourReset
|
||||
info.FiveHour.RemainingSeconds = int(time.Until(fiveHourReset).Seconds())
|
||||
} else {
|
||||
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 7天窗口
|
||||
if resp.SevenDay.ResetsAt != "" {
|
||||
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
|
||||
info.SevenDay = &UsageProgress{
|
||||
Utilization: resp.SevenDay.Utilization,
|
||||
ResetsAt: &sevenDayReset,
|
||||
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
|
||||
}
|
||||
} else {
|
||||
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
|
||||
info.SevenDay = &UsageProgress{
|
||||
Utilization: resp.SevenDay.Utilization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7天Sonnet窗口
|
||||
if resp.SevenDaySonnet.ResetsAt != "" {
|
||||
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
|
||||
info.SevenDaySonnet = &UsageProgress{
|
||||
Utilization: resp.SevenDaySonnet.Utilization,
|
||||
ResetsAt: &sonnetReset,
|
||||
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
|
||||
}
|
||||
} else {
|
||||
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
|
||||
info.SevenDaySonnet = &UsageProgress{
|
||||
Utilization: resp.SevenDaySonnet.Utilization,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
|
||||
func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo {
|
||||
info := &UsageInfo{}
|
||||
|
||||
// 如果有session_window信息
|
||||
if account.SessionWindowEnd != nil {
|
||||
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// 根据状态估算使用率 (百分比形式,100 = 100%)
|
||||
var utilization float64
|
||||
switch account.SessionWindowStatus {
|
||||
case "rejected":
|
||||
utilization = 100.0
|
||||
case "allowed_warning":
|
||||
utilization = 80.0
|
||||
default:
|
||||
utilization = 0.0
|
||||
}
|
||||
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: utilization,
|
||||
ResetsAt: account.SessionWindowEnd,
|
||||
RemainingSeconds: remaining,
|
||||
}
|
||||
} else {
|
||||
// 没有窗口信息,返回空数据
|
||||
info.FiveHour = &UsageProgress{
|
||||
Utilization: 0,
|
||||
RemainingSeconds: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Setup Token无法获取7d数据
|
||||
return info
|
||||
}
|
||||
|
||||
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
|
||||
// limit <= 0 means "no local quota window" (unknown or unlimited).
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
utilization := (float64(used) / float64(limit)) * 100
|
||||
remainingSeconds := int(resetAt.Sub(now).Seconds())
|
||||
if remainingSeconds < 0 {
|
||||
remainingSeconds = 0
|
||||
}
|
||||
resetCopy := resetAt
|
||||
return &UsageProgress{
|
||||
Utilization: utilization,
|
||||
ResetsAt: &resetCopy,
|
||||
RemainingSeconds: remainingSeconds,
|
||||
UsedRequests: used,
|
||||
LimitRequests: limit,
|
||||
WindowStats: &WindowStats{
|
||||
Requests: used,
|
||||
Tokens: tokens,
|
||||
Cost: cost,
|
||||
},
|
||||
}
|
||||
}
|
||||
1512
backend/internal/service/admin_service.go
Normal file
1512
backend/internal/service/admin_service.go
Normal file
File diff suppressed because it is too large
Load Diff
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
80
backend/internal/service/admin_service_bulk_update_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForBulkUpdate struct {
|
||||
accountRepoStub
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
if s.bulkUpdateErr != nil {
|
||||
return 0, s.bulkUpdateErr
|
||||
}
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
|
||||
if err, ok := s.bindGroupErrByID[accountID]; ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, result.Success)
|
||||
require.Equal(t, 0, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
|
||||
require.Empty(t, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
|
||||
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
bindGroupErrByID: map[int64]error{
|
||||
2: errors.New("bind failed"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
schedulable := false
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
GroupIDs: &groupIDs,
|
||||
Schedulable: &schedulable,
|
||||
SkipMixedChannelCheck: true,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, result.Success)
|
||||
require.Equal(t, 1, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
|
||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
67
backend/internal/service/admin_service_create_user_test.go
Normal file
67
backend/internal/service/admin_service_create_user_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdminService_CreateUser_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 10}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
input := &CreateUserInput{
|
||||
Email: "user@test.com",
|
||||
Password: "strong-pass",
|
||||
Username: "tester",
|
||||
Notes: "note",
|
||||
Balance: 12.5,
|
||||
Concurrency: 7,
|
||||
AllowedGroups: []int64{3, 5},
|
||||
}
|
||||
|
||||
user, err := svc.CreateUser(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(10), user.ID)
|
||||
require.Equal(t, input.Email, user.Email)
|
||||
require.Equal(t, input.Username, user.Username)
|
||||
require.Equal(t, input.Notes, user.Notes)
|
||||
require.Equal(t, input.Balance, user.Balance)
|
||||
require.Equal(t, input.Concurrency, user.Concurrency)
|
||||
require.Equal(t, input.AllowedGroups, user.AllowedGroups)
|
||||
require.Equal(t, RoleUser, user.Role)
|
||||
require.Equal(t, StatusActive, user.Status)
|
||||
require.True(t, user.CheckPassword(input.Password))
|
||||
require.Len(t, repo.created, 1)
|
||||
require.Equal(t, user, repo.created[0])
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "dup@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_CreateError(t *testing.T) {
|
||||
createErr := errors.New("db down")
|
||||
repo := &userRepoStub{createErr: createErr}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "user@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.ErrorIs(t, err, createErr)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
489
backend/internal/service/admin_service_delete_test.go
Normal file
489
backend/internal/service/admin_service_delete_test.go
Normal file
@@ -0,0 +1,489 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userRepoStub struct {
|
||||
user *User
|
||||
getErr error
|
||||
createErr error
|
||||
deleteErr error
|
||||
exists bool
|
||||
existsErr error
|
||||
nextID int64
|
||||
created []*User
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
|
||||
if s.createErr != nil {
|
||||
return s.createErr
|
||||
}
|
||||
if s.nextID != 0 && user.ID == 0 {
|
||||
user.ID = s.nextID
|
||||
}
|
||||
s.created = append(s.created, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
if s.user == nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return s.user, nil
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
panic("unexpected GetByEmail call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected UpdateBalance call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected DeductBalance call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
if s.existsErr != nil {
|
||||
return false, s.existsErr
|
||||
}
|
||||
return s.exists, nil
|
||||
}
|
||||
|
||||
func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
type groupRepoStub struct {
|
||||
affectedUserIDs []int64
|
||||
deleteErr error
|
||||
deleteCalls []int64
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Create(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
s.deleteCalls = append(s.deleteCalls, id)
|
||||
return s.affectedUserIDs, s.deleteErr
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListActive(ctx context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
accountCount int64
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) GetByID(ctx context.Context, id int64) (*Proxy, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Update(ctx context.Context, proxy *Proxy) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListActive(ctx context.Context) ([]Proxy, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||
panic("unexpected ListActiveWithAccountCount call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFiltersAndAccountCount call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
panic("unexpected ExistsByHostPortAuth call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
if s.countErr != nil {
|
||||
return 0, s.countErr
|
||||
}
|
||||
return s.accountCount, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
panic("unexpected ListAccountSummariesByProxyID call")
|
||||
}
|
||||
|
||||
type redeemRepoStub struct {
|
||||
deleteErrByID map[int64]error
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) CreateBatch(ctx context.Context, codes []RedeemCode) error {
|
||||
panic("unexpected CreateBatch call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
|
||||
panic("unexpected GetByCode call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Update(ctx context.Context, code *RedeemCode) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
if s.deleteErrByID != nil {
|
||||
if err, ok := s.deleteErrByID[id]; ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Use(ctx context.Context, id, userID int64) error {
|
||||
panic("unexpected Use call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
type subscriptionInvalidateCall struct {
|
||||
userID int64
|
||||
groupID int64
|
||||
}
|
||||
|
||||
type billingCacheStub struct {
|
||||
invalidations chan subscriptionInvalidateCall
|
||||
}
|
||||
|
||||
func newBillingCacheStub(buffer int) *billingCacheStub {
|
||||
return &billingCacheStub{invalidations: make(chan subscriptionInvalidateCall, buffer)}
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
panic("unexpected GetUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
panic("unexpected SetUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
panic("unexpected DeductUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
panic("unexpected InvalidateUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
panic("unexpected GetSubscriptionCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
panic("unexpected SetSubscriptionCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
panic("unexpected UpdateSubscriptionUsage call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
s.invalidations <- subscriptionInvalidateCall{userID: userID, groupID: groupID}
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
|
||||
t.Helper()
|
||||
calls := make([]subscriptionInvalidateCall, 0, expected)
|
||||
timeout := time.After(2 * time.Second)
|
||||
for len(calls) < expected {
|
||||
select {
|
||||
case call := <-ch:
|
||||
calls = append(calls, call)
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout waiting for %d invalidations, got %d", expected, len(calls))
|
||||
}
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_Success(t *testing.T) {
|
||||
repo := &userRepoStub{user: &User{ID: 7, Role: RoleUser}}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_NotFound(t *testing.T) {
|
||||
repo := &userRepoStub{getErr: ErrUserNotFound}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 404)
|
||||
require.ErrorIs(t, err, ErrUserNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_AdminGuard(t *testing.T) {
|
||||
repo := &userRepoStub{user: &User{ID: 1, Role: RoleAdmin}}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "cannot delete admin user")
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_DeleteError(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &userRepoStub{
|
||||
user: &User{ID: 9, Role: RoleUser},
|
||||
deleteErr: deleteErr,
|
||||
}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 9)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
require.Equal(t, []int64{9}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) {
|
||||
cache := newBillingCacheStub(2)
|
||||
repo := &groupRepoStub{affectedUserIDs: []int64{11, 12}}
|
||||
svc := &adminServiceImpl{
|
||||
groupRepo: repo,
|
||||
billingCacheService: &BillingCacheService{cache: cache},
|
||||
}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{5}, repo.deleteCalls)
|
||||
|
||||
calls := waitForInvalidations(t, cache.invalidations, 2)
|
||||
require.ElementsMatch(t, []subscriptionInvalidateCall{
|
||||
{userID: 11, groupID: 5},
|
||||
{userID: 12, groupID: 5},
|
||||
}, calls)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
|
||||
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 99)
|
||||
require.ErrorIs(t, err, ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &groupRepoStub{deleteErr: deleteErr}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 42)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Success(t *testing.T) {
|
||||
repo := &proxyRepoStub{}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
|
||||
repo := &proxyRepoStub{}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 404)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{404}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
|
||||
repo := &proxyRepoStub{accountCount: 2}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 77)
|
||||
require.ErrorIs(t, err, ErrProxyInUse)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &proxyRepoStub{deleteErr: deleteErr}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 33)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteRedeemCode_Success(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
err := svc.DeleteRedeemCode(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{10}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteRedeemCode_Idempotent(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
err := svc.DeleteRedeemCode(context.Background(), 999)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{999}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteRedeemCode_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &redeemRepoStub{deleteErrByID: map[int64]error{1: deleteErr}}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
err := svc.DeleteRedeemCode(context.Background(), 1)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
require.Equal(t, []int64{1}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_BatchDeleteRedeemCodes_Success(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
deleted, err := svc.BatchDeleteRedeemCodes(context.Background(), []int64{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), deleted)
|
||||
require.Equal(t, []int64{1, 2, 3}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_BatchDeleteRedeemCodes_PartialFailures(t *testing.T) {
|
||||
repo := &redeemRepoStub{
|
||||
deleteErrByID: map[int64]error{
|
||||
2: errors.New("db error"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
deleted, err := svc.BatchDeleteRedeemCodes(context.Background(), []int64{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted)
|
||||
require.Equal(t, []int64{1, 2, 3}, repo.deletedIDs)
|
||||
}
|
||||
380
backend/internal/service/admin_service_group_test.go
Normal file
380
backend/internal/service/admin_service_group_test.go
Normal file
@@ -0,0 +1,380 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
|
||||
type groupRepoStubForAdmin struct {
|
||||
created *Group // 记录 Create 调用的参数
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersIsExclusive *bool
|
||||
listWithFiltersGroups []Group
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
s.listWithFiltersIsExclusive = isExclusive
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersGroups)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersGroups, result, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了正确的字段
|
||||
require.NotNil(t, repo.created)
|
||||
require.NotNil(t, repo.created.ImagePrice1K)
|
||||
require.NotNil(t, repo.created.ImagePrice2K)
|
||||
require.NotNil(t, repo.created.ImagePrice4K)
|
||||
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
|
||||
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
// ImagePrice 字段全部为 nil
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 ImagePrice 字段为 nil
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.ImagePrice1K)
|
||||
require.Nil(t, repo.created.ImagePrice2K)
|
||||
require.Nil(t, repo.created.ImagePrice4K)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
|
||||
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.12
|
||||
price2K := 0.18
|
||||
price4K := 0.36
|
||||
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了更新后的字段
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.NotNil(t, repo.updated.ImagePrice4K)
|
||||
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
|
||||
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
oldPrice2K := 0.15
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
// 只更新 1K 价格
|
||||
price1K := 0.10
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
// ImagePrice2K 和 ImagePrice4K 为 nil,不更新
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证:1K 被更新,2K 保持原值,4K 仍为 nil
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
// 测试:
|
||||
// 1. search 参数正常传递到 repository 层
|
||||
// 2. search 为空字符串时的行为
|
||||
// 3. search 与其他过滤条件组合使用
|
||||
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, groups)
|
||||
require.Equal(t, int64(0), total)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||||
isExclusive := true
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), total)
|
||||
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||||
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
fallbackID := int64(2)
|
||||
repo := &groupRepoStubForFallbackCycle{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
FallbackGroupID: &fallbackID,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
FallbackGroupID: &groupID,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
type groupRepoStubForFallbackCycle struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
238
backend/internal/service/admin_service_search_test.go
Normal file
238
backend/internal/service/admin_service_search_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForAdminList struct {
|
||||
accountRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersAccounts []Account
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersType = accountType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAccounts)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAccounts, result, nil
|
||||
}
|
||||
|
||||
type proxyRepoStubForAdminList struct {
|
||||
proxyRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersProtocol string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersProxies []Proxy
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
|
||||
listWithFiltersAndAccountCountCalls int
|
||||
listWithFiltersAndAccountCountParams pagination.PaginationParams
|
||||
listWithFiltersAndAccountCountProtocol string
|
||||
listWithFiltersAndAccountCountStatus string
|
||||
listWithFiltersAndAccountCountSearch string
|
||||
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
|
||||
listWithFiltersAndAccountCountResult *pagination.PaginationResult
|
||||
listWithFiltersAndAccountCountErr error
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersProtocol = protocol
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersProxies, result, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersAndAccountCountCalls++
|
||||
s.listWithFiltersAndAccountCountParams = params
|
||||
s.listWithFiltersAndAccountCountProtocol = protocol
|
||||
s.listWithFiltersAndAccountCountStatus = status
|
||||
s.listWithFiltersAndAccountCountSearch = search
|
||||
|
||||
if s.listWithFiltersAndAccountCountErr != nil {
|
||||
return nil, nil, s.listWithFiltersAndAccountCountErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersAndAccountCountResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAndAccountCountProxies, result, nil
|
||||
}
|
||||
|
||||
type redeemRepoStubForAdminList struct {
|
||||
redeemRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersCodes []RedeemCode
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersType = codeType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersCodes)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersCodes, result, nil
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "acc", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(7), total)
|
||||
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "http", repo.listWithFiltersProtocol)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "p1", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
|
||||
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), total)
|
||||
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
|
||||
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
|
||||
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
|
||||
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &redeemRepoStubForAdminList{
|
||||
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), total)
|
||||
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "ABC", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type balanceUserRepoStub struct {
|
||||
*userRepoStub
|
||||
updateErr error
|
||||
updated []*User
|
||||
}
|
||||
|
||||
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
|
||||
if s.updateErr != nil {
|
||||
return s.updateErr
|
||||
}
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *user
|
||||
s.updated = append(s.updated, &clone)
|
||||
if s.userRepoStub != nil {
|
||||
s.userRepoStub.user = &clone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceRedeemRepoStub struct {
|
||||
*redeemRepoStub
|
||||
created []*RedeemCode
|
||||
}
|
||||
|
||||
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||
if code == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *code
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
type authCacheInvalidatorStub struct {
|
||||
userIDs []int64
|
||||
groupIDs []int64
|
||||
keys []string
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||
s.keys = append(s.keys, key)
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||
s.userIDs = append(s.userIDs, userID)
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||
s.groupIDs = append(s.groupIDs, groupID)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
|
||||
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: redeemRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||
require.Len(t, redeemRepo.created, 1)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
|
||||
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: redeemRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, invalidator.userIDs)
|
||||
require.Empty(t, redeemRepo.created)
|
||||
}
|
||||
2474
backend/internal/service/antigravity_gateway_service.go
Normal file
2474
backend/internal/service/antigravity_gateway_service.go
Normal file
File diff suppressed because it is too large
Load Diff
83
backend/internal/service/antigravity_gateway_service_test.go
Normal file
83
backend/internal/service/antigravity_gateway_service_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) {
|
||||
req := &antigravity.ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Thinking: &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: 1024,
|
||||
},
|
||||
Messages: []antigravity.ClaudeMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"thinking","thinking":"secret plan","signature":""},
|
||||
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
|
||||
]`),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: json.RawMessage(`[
|
||||
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false},
|
||||
{"type":"redacted_thinking","data":"..."}
|
||||
]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed, err := stripSignatureSensitiveBlocksFromClaudeRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Nil(t, req.Thinking)
|
||||
|
||||
require.Len(t, req.Messages, 2)
|
||||
|
||||
var blocks0 []map[string]any
|
||||
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks0))
|
||||
require.Len(t, blocks0, 2)
|
||||
require.Equal(t, "text", blocks0[0]["type"])
|
||||
require.Equal(t, "secret plan", blocks0[0]["text"])
|
||||
require.Equal(t, "text", blocks0[1]["type"])
|
||||
|
||||
var blocks1 []map[string]any
|
||||
require.NoError(t, json.Unmarshal(req.Messages[1].Content, &blocks1))
|
||||
require.Len(t, blocks1, 1)
|
||||
require.Equal(t, "text", blocks1[0]["type"])
|
||||
require.NotEmpty(t, blocks1[0]["text"])
|
||||
}
|
||||
|
||||
func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) {
|
||||
req := &antigravity.ClaudeRequest{
|
||||
Model: "claude-sonnet-4-5",
|
||||
Thinking: &antigravity.ThinkingConfig{
|
||||
Type: "enabled",
|
||||
BudgetTokens: 1024,
|
||||
},
|
||||
Messages: []antigravity.ClaudeMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: json.RawMessage(`[{"type":"thinking","thinking":"secret plan"},{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changed, err := stripThinkingFromClaudeRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.True(t, changed)
|
||||
require.Nil(t, req.Thinking)
|
||||
|
||||
var blocks []map[string]any
|
||||
require.NoError(t, json.Unmarshal(req.Messages[0].Content, &blocks))
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, "text", blocks[0]["type"])
|
||||
require.Equal(t, "secret plan", blocks[0]["text"])
|
||||
require.Equal(t, "tool_use", blocks[1]["type"])
|
||||
}
|
||||
123
backend/internal/service/antigravity_image_test.go
Normal file
123
backend/internal/service/antigravity_image_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsImageGenerationModel_GeminiProImage 测试 gemini-3-pro-image 识别
|
||||
func TestIsImageGenerationModel_GeminiProImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image-preview"))
|
||||
require.True(t, isImageGenerationModel("models/gemini-3-pro-image"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_GeminiFlashImage 测试 gemini-2.5-flash-image 识别
|
||||
func TestIsImageGenerationModel_GeminiFlashImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image-preview"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_RegularModel 测试普通模型不被识别为图片模型
|
||||
func TestIsImageGenerationModel_RegularModel(t *testing.T) {
|
||||
require.False(t, isImageGenerationModel("claude-3-opus"))
|
||||
require.False(t, isImageGenerationModel("claude-sonnet-4-20250514"))
|
||||
require.False(t, isImageGenerationModel("gpt-4o"))
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
|
||||
// 验证不会误匹配包含关键词的自定义模型名
|
||||
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
|
||||
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
|
||||
func TestIsImageGenerationModel_CaseInsensitive(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("GEMINI-3-PRO-IMAGE"))
|
||||
require.True(t, isImageGenerationModel("Gemini-3-Pro-Image"))
|
||||
require.True(t, isImageGenerationModel("GEMINI-2.5-FLASH-IMAGE"))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_ValidSizes 测试有效尺寸解析
|
||||
func TestExtractImageSize_ValidSizes(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 1K
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
// 2K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 4K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
|
||||
func TestExtractImageSize_CaseInsensitive(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
|
||||
func TestExtractImageSize_Default(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 无 generationConfig
|
||||
body := []byte(`{"contents":[]}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 generationConfig 但无 imageConfig
|
||||
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 imageConfig 但无 imageSize
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
|
||||
func TestExtractImageSize_InvalidJSON(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`not valid json`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"broken":`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
|
||||
func TestExtractImageSize_EmptySize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 空格
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
|
||||
func TestExtractImageSize_InvalidSize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
269
backend/internal/service/antigravity_model_mapping_test.go
Normal file
269
backend/internal/service/antigravity_model_mapping_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAntigravityModelSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持的模型
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"直接支持 - claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"直接支持 - gemini-2.5-flash", "gemini-2.5-flash", true},
|
||||
{"直接支持 - gemini-2.5-flash-lite", "gemini-2.5-flash-lite", true},
|
||||
{"直接支持 - gemini-3-pro-high", "gemini-3-pro-high", true},
|
||||
|
||||
// 可映射的模型
|
||||
{"可映射 - claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20241022", true},
|
||||
{"可映射 - claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620", true},
|
||||
{"可映射 - claude-opus-4", "claude-opus-4", true},
|
||||
{"可映射 - claude-haiku-4", "claude-haiku-4", true},
|
||||
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
|
||||
|
||||
// Gemini 前缀透传
|
||||
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
|
||||
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
|
||||
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
|
||||
|
||||
// Claude 前缀兜底
|
||||
{"Claude前缀 - claude-unknown-model", "claude-unknown-model", true},
|
||||
{"Claude前缀 - claude-3-opus-20240229", "claude-3-opus-20240229", true},
|
||||
{"Claude前缀 - claude-future-version", "claude-future-version", true},
|
||||
|
||||
// 不支持的模型
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - gpt-4o", "gpt-4o", false},
|
||||
{"不支持 - llama-3", "llama-3", false},
|
||||
{"不支持 - mistral-7b", "mistral-7b", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsAntigravityModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
accountMapping map[string]string
|
||||
expected string
|
||||
}{
|
||||
// 1. 账户级映射优先(注意:model_mapping 在 credentials 中存储为 map[string]any)
|
||||
{
|
||||
name: "账户映射优先",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
|
||||
expected: "custom-model",
|
||||
},
|
||||
{
|
||||
name: "账户映射覆盖系统映射",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
|
||||
expected: "my-opus",
|
||||
},
|
||||
|
||||
// 2. 系统默认映射
|
||||
{
|
||||
name: "系统映射 - claude-3-5-sonnet-20241022",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-5-sonnet-20240620",
|
||||
requestedModel: "claude-3-5-sonnet-20240620",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-opus-4",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-opus-4-5-20251101",
|
||||
requestedModel: "claude-opus-4-5-20251101",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-3-haiku-20240307",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4-5-20251001",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "系统映射 - claude-sonnet-4-5-20250929",
|
||||
requestedModel: "claude-sonnet-4-5-20250929",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
|
||||
// 3. Gemini 透传
|
||||
{
|
||||
name: "Gemini透传 - gemini-2.5-flash",
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-1.5-pro",
|
||||
requestedModel: "gemini-1.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-1.5-pro",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
requestedModel: "gemini-future-model",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-future-model",
|
||||
},
|
||||
|
||||
// 4. 直接支持的模型
|
||||
{
|
||||
name: "直接支持 - claude-sonnet-4-5",
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "直接支持 - claude-opus-4-5-thinking",
|
||||
requestedModel: "claude-opus-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "直接支持 - claude-sonnet-4-5-thinking",
|
||||
requestedModel: "claude-sonnet-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
|
||||
// 5. 默认值 fallback(未知 claude 模型)
|
||||
{
|
||||
name: "默认值 - claude-unknown",
|
||||
requestedModel: "claude-unknown",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "默认值 - claude-3-opus-20240229",
|
||||
requestedModel: "claude-3-opus-20240229",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
if tt.accountMapping != nil {
|
||||
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
|
||||
mappingAny := make(map[string]any)
|
||||
for k, v := range tt.accountMapping {
|
||||
mappingAny[k] = v
|
||||
}
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": mappingAny,
|
||||
}
|
||||
}
|
||||
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expected string
|
||||
}{
|
||||
// 空字符串回退到默认值
|
||||
{"空字符串", "", "claude-sonnet-4-5"},
|
||||
|
||||
// 非 claude/gemini 前缀回退到默认值
|
||||
{"非claude/gemini前缀 - gpt", "gpt-4", "claude-sonnet-4-5"},
|
||||
{"非claude/gemini前缀 - llama", "llama-3", "claude-sonnet-4-5"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
|
||||
|
||||
// 可映射
|
||||
{"可映射 - claude-opus-4", "claude-opus-4", true},
|
||||
|
||||
// 前缀透传
|
||||
{"Gemini前缀", "gemini-unknown", true},
|
||||
{"Claude前缀", "claude-unknown", true},
|
||||
|
||||
// 不支持
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.IsModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
276
backend/internal/service/antigravity_oauth_service.go
Normal file
276
backend/internal/service/antigravity_oauth_service.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
type AntigravityOAuthService struct {
|
||||
sessionStore *antigravity.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
|
||||
return &AntigravityOAuthService{
|
||||
sessionStore: antigravity.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// AntigravityAuthURLResult is the result of generating an authorization URL
|
||||
type AntigravityAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成 Google OAuth 授权链接
|
||||
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
|
||||
state, err := antigravity.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 state 失败: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := antigravity.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
|
||||
}
|
||||
|
||||
sessionID, err := antigravity.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 session_id 失败: %w", err)
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
session := &antigravity.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
|
||||
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
return &AntigravityAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AntigravityExchangeCodeInput 交换 code 的输入
|
||||
type AntigravityExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// AntigravityTokenInfo token 信息
|
||||
type AntigravityTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session 不存在或已过期")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
return nil, fmt.Errorf("state 无效")
|
||||
}
|
||||
|
||||
// 确定代理 URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 交换 token
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除 session
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间(减去 5 分钟安全窗口)
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
|
||||
result := &AntigravityTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userInfo, err := client.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||
} else {
|
||||
result.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(部分账户类型可能没有)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
||||
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
// 兜底:随机生成 project_id
|
||||
if result.ProjectID == "" {
|
||||
result.ProjectID = antigravity.GenerateMockProjectID()
|
||||
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 token
|
||||
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
tokenResp, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err == nil {
|
||||
now := time.Now()
|
||||
expiresAt := now.Unix() + tokenResp.ExpiresIn - 300
|
||||
fmt.Printf("[AntigravityOAuth] Token refreshed: expires_in=%d, expires_at=%d (%s)\n",
|
||||
tokenResp.ExpiresIn, expiresAt, time.Unix(expiresAt, 0).Format("2006-01-02 15:04:05"))
|
||||
return &AntigravityTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if isNonRetryableAntigravityOAuthError(err) {
|
||||
return nil, err
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||
}
|
||||
|
||||
func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||
msg := err.Error()
|
||||
nonRetryable := []string{
|
||||
"invalid_grant",
|
||||
"invalid_client",
|
||||
"unauthorized_client",
|
||||
"access_denied",
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RefreshAccountToken 刷新账户的 token
|
||||
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("非 Antigravity OAuth 账户")
|
||||
}
|
||||
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return nil, fmt.Errorf("无可用的 refresh_token")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留原有的 project_id 和 email
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||
if existingEmail != "" {
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.TokenType != "" {
|
||||
creds["token_type"] = tokenInfo.TokenType
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (s *AntigravityOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
111
backend/internal/service/antigravity_quota_fetcher.go
Normal file
111
backend/internal/service/antigravity_quota_fetcher.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
|
||||
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
|
||||
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
|
||||
}
|
||||
|
||||
// CanFetch 检查是否可以获取此账户的额度
|
||||
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
accessToken := account.GetCredential("access_token")
|
||||
return accessToken != ""
|
||||
}
|
||||
|
||||
// FetchQuota 获取 Antigravity 账户额度信息
|
||||
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
// 如果没有 project_id,生成一个随机的
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
Raw: modelsRaw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
|
||||
utilization := int((1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100)
|
||||
|
||||
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
|
||||
for _, modelName := range priorityModels {
|
||||
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
|
||||
utilization := (1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100
|
||||
progress := &UsageProgress{
|
||||
Utilization: utilization,
|
||||
}
|
||||
if modelInfo.QuotaInfo.ResetTime != "" {
|
||||
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
|
||||
progress.ResetsAt = &resetTime
|
||||
progress.RemainingSeconds = int(time.Until(resetTime).Seconds())
|
||||
}
|
||||
}
|
||||
info.FiveHour = progress
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GetProxyURL 获取账户的代理 URL
|
||||
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) string {
|
||||
if account.ProxyID == nil || f.proxyRepo == nil {
|
||||
return ""
|
||||
}
|
||||
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err != nil || proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
88
backend/internal/service/antigravity_quota_scope.go
Normal file
88
backend/internal/service/antigravity_quota_scope.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const antigravityQuotaScopesKey = "antigravity_quota_scopes"
|
||||
|
||||
// AntigravityQuotaScope 表示 Antigravity 的配额域
|
||||
type AntigravityQuotaScope string
|
||||
|
||||
const (
|
||||
AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude"
|
||||
AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text"
|
||||
AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image"
|
||||
)
|
||||
|
||||
// resolveAntigravityQuotaScope 根据模型名称解析配额域
|
||||
func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) {
|
||||
model := normalizeAntigravityModelName(requestedModel)
|
||||
if model == "" {
|
||||
return "", false
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude-"):
|
||||
return AntigravityQuotaScopeClaude, true
|
||||
case strings.HasPrefix(model, "gemini-"):
|
||||
if isImageGenerationModel(model) {
|
||||
return AntigravityQuotaScopeGeminiImage, true
|
||||
}
|
||||
return AntigravityQuotaScopeGeminiText, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAntigravityModelName(model string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
normalized = strings.TrimPrefix(normalized, "models/")
|
||||
return normalized
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度
|
||||
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if !a.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
if a.Platform != PlatformAntigravity {
|
||||
return true
|
||||
}
|
||||
scope, ok := resolveAntigravityQuotaScope(requestedModel)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
resetAt := a.antigravityQuotaScopeResetAt(scope)
|
||||
if resetAt == nil {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
return !now.Before(*resetAt)
|
||||
}
|
||||
|
||||
func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time {
|
||||
if a == nil || a.Extra == nil || scope == "" {
|
||||
return nil
|
||||
}
|
||||
rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
rawScope, ok := rawScopes[string(scope)].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string)
|
||||
if !ok || strings.TrimSpace(resetAtRaw) == "" {
|
||||
return nil
|
||||
}
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &resetAt
|
||||
}
|
||||
130
backend/internal/service/antigravity_token_provider.go
Normal file
130
backend/internal/service/antigravity_token_provider.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type AntigravityTokenCache = GeminiTokenCache
|
||||
|
||||
// AntigravityTokenProvider 管理 Antigravity 账户的 access_token
|
||||
type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache AntigravityTokenCache,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
) *AntigravityTokenProvider {
|
||||
return &AntigravityTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an antigravity oauth account")
|
||||
}
|
||||
|
||||
cacheKey := AntigravityTokenCacheKey(account)
|
||||
|
||||
// 1. 先尝试缓存
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// 从数据库获取最新账户信息
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
||||
if p.antigravityOAuthService == nil {
|
||||
return "", errors.New("antigravity oauth service not configured")
|
||||
}
|
||||
tokenInfo, err := p.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. 存入缓存
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "ag:" + projectID
|
||||
}
|
||||
return "ag:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
65
backend/internal/service/antigravity_token_refresher.go
Normal file
65
backend/internal/service/antigravity_token_refresher.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// antigravityRefreshWindow Antigravity token 提前刷新窗口:15分钟
|
||||
// Google OAuth token 有效期55分钟,提前15分钟刷新
|
||||
antigravityRefreshWindow = 15 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenRefresher 实现 TokenRefresher 接口
|
||||
type AntigravityTokenRefresher struct {
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
|
||||
return &AntigravityTokenRefresher{
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否可以刷新此账户
|
||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查账户是否需要刷新
|
||||
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
|
||||
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
timeUntilExpiry := time.Until(*expiresAt)
|
||||
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
|
||||
if needsRefresh {
|
||||
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
|
||||
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
|
||||
}
|
||||
return needsRefresh
|
||||
}
|
||||
|
||||
// Refresh 执行 token 刷新
|
||||
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
22
backend/internal/service/api_key.go
Normal file
22
backend/internal/service/api_key.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
46
backend/internal/service/api_key_auth_cache.go
Normal file
46
backend/internal/service/api_key_auth_cache.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package service
|
||||
|
||||
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||
type APIKeyAuthSnapshot struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||
User APIKeyAuthUserSnapshot `json:"user"`
|
||||
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthUserSnapshot 用户快照
|
||||
type APIKeyAuthUserSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
}
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
type APIKeyAuthGroupSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
type APIKeyAuthCacheEntry struct {
|
||||
NotFound bool `json:"not_found"`
|
||||
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
|
||||
}
|
||||
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
269
backend/internal/service/api_key_auth_cache_impl.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
l1TTL time.Duration
|
||||
l2TTL time.Duration
|
||||
negativeTTL time.Duration
|
||||
jitterPercent int
|
||||
singleflight bool
|
||||
}
|
||||
|
||||
var (
|
||||
jitterRandMu sync.Mutex
|
||||
// 认证缓存抖动使用独立随机源,避免全局 Seed
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
)
|
||||
|
||||
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||
if cfg == nil {
|
||||
return apiKeyAuthCacheConfig{}
|
||||
}
|
||||
auth := cfg.APIKeyAuth
|
||||
return apiKeyAuthCacheConfig{
|
||||
l1Size: auth.L1Size,
|
||||
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
|
||||
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
|
||||
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
|
||||
jitterPercent: auth.JitterPercent,
|
||||
singleflight: auth.Singleflight,
|
||||
}
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
|
||||
return c.l1Size > 0 && c.l1TTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
|
||||
return c.l2TTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||
return c.negativeTTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return ttl
|
||||
}
|
||||
if c.jitterPercent <= 0 {
|
||||
return ttl
|
||||
}
|
||||
percent := c.jitterPercent
|
||||
if percent > 100 {
|
||||
percent = 100
|
||||
}
|
||||
delta := float64(percent) / 100
|
||||
jitterRandMu.Lock()
|
||||
randVal := jitterRand.Float64()
|
||||
jitterRandMu.Unlock()
|
||||
factor := 1 - delta + randVal*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
}
|
||||
return time.Duration(float64(ttl) * factor)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
|
||||
if !s.authCfg.l1Enabled() {
|
||||
return
|
||||
}
|
||||
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||
NumCounters: int64(s.authCfg.l1Size) * 10,
|
||||
MaxCost: int64(s.authCfg.l1Size),
|
||||
BufferItems: 64,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.authCacheL1 = cache
|
||||
}
|
||||
|
||||
func (s *APIKeyService) authCacheKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
|
||||
if s.authCacheL1 != nil {
|
||||
if val, ok := s.authCacheL1.Get(cacheKey); ok {
|
||||
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
|
||||
return entry, true
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||
return nil, false
|
||||
}
|
||||
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
s.setAuthCacheL1(cacheKey, entry)
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
|
||||
if s.authCacheL1 == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
ttl := s.authCfg.l1TTL
|
||||
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
|
||||
ttl = s.authCfg.negativeTTL
|
||||
}
|
||||
ttl = s.authCfg.jitterTTL(ttl)
|
||||
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
s.setAuthCacheL1(cacheKey, entry)
|
||||
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||
return
|
||||
}
|
||||
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
|
||||
}
|
||||
|
||||
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||
if s.authCacheL1 != nil {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAPIKeyNotFound) {
|
||||
entry := &APIKeyAuthCacheEntry{NotFound: true}
|
||||
if s.authCfg.negativeEnabled() {
|
||||
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||
if snapshot == nil {
|
||||
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||
}
|
||||
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
|
||||
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
|
||||
if entry == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if entry.NotFound {
|
||||
return nil, true, ErrAPIKeyNotFound
|
||||
}
|
||||
if entry.Snapshot == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
if apiKey == nil || apiKey.User == nil {
|
||||
return nil
|
||||
}
|
||||
snapshot := &APIKeyAuthSnapshot{
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: apiKey.UserID,
|
||||
GroupID: apiKey.GroupID,
|
||||
Status: apiKey.Status,
|
||||
IPWhitelist: apiKey.IPWhitelist,
|
||||
IPBlacklist: apiKey.IPBlacklist,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: apiKey.User.ID,
|
||||
Status: apiKey.User.Status,
|
||||
Role: apiKey.User.Role,
|
||||
Balance: apiKey.User.Balance,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
},
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey := &APIKey{
|
||||
ID: snapshot.APIKeyID,
|
||||
UserID: snapshot.UserID,
|
||||
GroupID: snapshot.GroupID,
|
||||
Key: key,
|
||||
Status: snapshot.Status,
|
||||
IPWhitelist: snapshot.IPWhitelist,
|
||||
IPBlacklist: snapshot.IPBlacklist,
|
||||
User: &User{
|
||||
ID: snapshot.User.ID,
|
||||
Status: snapshot.User.Status,
|
||||
Role: snapshot.User.Role,
|
||||
Balance: snapshot.User.Balance,
|
||||
Concurrency: snapshot.User.Concurrency,
|
||||
},
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
}
|
||||
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
48
backend/internal/service/api_key_auth_cache_invalidate.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
cacheKey := s.authCacheKey(key)
|
||||
s.deleteAuthCache(ctx, cacheKey)
|
||||
}
|
||||
|
||||
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||
if userID <= 0 {
|
||||
return
|
||||
}
|
||||
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.deleteAuthCacheByKeys(ctx, keys)
|
||||
}
|
||||
|
||||
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||
if groupID <= 0 {
|
||||
return
|
||||
}
|
||||
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.deleteAuthCacheByKeys(ctx, keys)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
s.deleteAuthCache(ctx, s.authCacheKey(key))
|
||||
}
|
||||
}
|
||||
570
backend/internal/service/api_key_service.go
Normal file
570
backend/internal/service/api_key_service.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景
|
||||
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
// GetByKeyForAuth 认证专用查询,返回最小字段集
|
||||
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
|
||||
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||
DeleteAuthCache(ctx context.Context, key string) error
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||
type APIKeyAuthCacheInvalidator interface {
|
||||
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
|
||||
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
func NewAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.initAuthCache(cfg)
|
||||
return svc
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.APIKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
|
||||
key := prefix + hex.EncodeToString(bytes)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrAPIKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if (c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrAPIKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrAPIKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
return err == nil // 有有效订阅则允许
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户是否可以绑定该分组
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
}
|
||||
|
||||
var key string
|
||||
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证自定义Key格式
|
||||
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查Key是否已存在
|
||||
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check key exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||||
return nil, ErrAPIKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
} else {
|
||||
// 生成随机API Key
|
||||
var err error
|
||||
key, err = s.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("create api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
}
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verify api key ownership: %w", err)
|
||||
}
|
||||
return validIDs, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
cacheKey := s.authCacheKey(key)
|
||||
|
||||
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCfg.singleflight {
|
||||
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
|
||||
return s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry, _ := value.(*APIKeyAuthCacheEntry)
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
} else {
|
||||
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if apiKey.UserID != userID {
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.GroupID != nil {
|
||||
// 验证分组权限
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
|
||||
apiKey.GroupID = req.GroupID
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 IP 限制(空数组会清空设置)
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Delete 删除API Key
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证当前用户是否为该 API Key 的所有者
|
||||
if ownerID != userID {
|
||||
return ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 清除Redis缓存(使用 userID 而非 apiKey.UserID)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
s.InvalidateAuthCacheByKey(ctx, key)
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete api key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 检查API Key状态
|
||||
if !apiKey.IsActive() {
|
||||
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
return apiKey, user, nil
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户有权限绑定的分组列表
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有活跃分组
|
||||
allGroups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
|
||||
// 获取用户的所有有效订阅
|
||||
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
|
||||
// 构建订阅分组 ID 集合
|
||||
subscribedGroupIDs := make(map[int64]bool)
|
||||
for _, sub := range activeSubscriptions {
|
||||
subscribedGroupIDs[sub.GroupID] = true
|
||||
}
|
||||
|
||||
// 过滤出用户有权限的分组
|
||||
availableGroups := make([]Group, 0)
|
||||
for _, group := range allGroups {
|
||||
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
|
||||
availableGroups = append(availableGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return availableGroups, nil
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
417
backend/internal/service/api_key_service_cache_test.go
Normal file
417
backend/internal/service/api_key_service_cache_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type authRepoStub struct {
|
||||
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
|
||||
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
|
||||
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
panic("unexpected GetKeyAndOwnerID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||
if s.getByKeyForAuth == nil {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
return s.getByKeyForAuth(ctx, key)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
if s.listKeysByUserID == nil {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
return s.listKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
if s.listKeysByGroupID == nil {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
return s.listKeysByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
setAuthKeys []string
|
||||
deleteAuthKeys []string
|
||||
}
|
||||
|
||||
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
if s.getAuthCache == nil {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return s.getAuthCache(ctx, key)
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
s.setAuthKeys = append(s.setAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, errors.New("unexpected repo call")
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
groupID := int64(9)
|
||||
cacheEntry := &APIKeyAuthCacheEntry{
|
||||
Snapshot: &APIKeyAuthSnapshot{
|
||||
APIKeyID: 1,
|
||||
UserID: 2,
|
||||
GroupID: &groupID,
|
||||
Status: StatusActive,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: 2,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
},
|
||||
Group: &APIKeyAuthGroupSnapshot{
|
||||
ID: groupID,
|
||||
Name: "g",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
RateMultiplier: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return cacheEntry, nil
|
||||
}
|
||||
|
||||
apiKey, err := svc.GetByKey(context.Background(), "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), apiKey.ID)
|
||||
require.Equal(t, int64(2), apiKey.User.ID)
|
||||
require.Equal(t, groupID, apiKey.Group.ID)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, errors.New("unexpected repo call")
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||
}
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return &APIKey{
|
||||
ID: 5,
|
||||
UserID: 7,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 7,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 12,
|
||||
Concurrency: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
|
||||
apiKey, err := svc.GetByKey(context.Background(), "k2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(5), apiKey.ID)
|
||||
require.Len(t, cache.setAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||
var calls int32
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &APIKey{
|
||||
ID: 21,
|
||||
UserID: 3,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 3,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 5,
|
||||
Concurrency: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L1Size: 1000,
|
||||
L1TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
require.NotNil(t, svc.authCacheL1)
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||
require.NoError(t, err)
|
||||
svc.authCacheL1.Wait()
|
||||
cacheKey := svc.authCacheKey("k-l1")
|
||||
_, ok := svc.authCacheL1.Get(cacheKey)
|
||||
require.True(t, ok)
|
||||
_, err = svc.GetByKey(context.Background(), "k-l1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||
return []string{"k1", "k2"}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
|
||||
return []string{"k1", "k2"}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||
require.Len(t, cache.deleteAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, ErrAPIKeyNotFound
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Len(t, cache.setAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||
var calls int32
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return &APIKey{
|
||||
ID: 11,
|
||||
UserID: 2,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 2,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 1,
|
||||
Concurrency: 1,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
Singleflight: true,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
|
||||
start := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
errs := make([]error, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, err := svc.GetByKey(context.Background(), "k1")
|
||||
errs[idx] = err
|
||||
}(i)
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
for _, err := range errs {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||
}
|
||||
252
backend/internal/service/api_key_service_delete_test.go
Normal file
252
backend/internal/service/api_key_service_delete_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
//go:build unit
|
||||
|
||||
// API Key 服务删除方法的单元测试
|
||||
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
|
||||
// 包括权限验证、缓存清理和错误处理
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||
}
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
if s.getByIDErr != nil {
|
||||
return nil, s.getByIDErr
|
||||
}
|
||||
if s.apiKey != nil {
|
||||
clone := *s.apiKey
|
||||
return &clone, nil
|
||||
}
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
if s.getByIDErr != nil {
|
||||
return "", 0, s.getByIDErr
|
||||
}
|
||||
if s.apiKey != nil {
|
||||
return s.apiKey.Key, s.apiKey.UserID, nil
|
||||
}
|
||||
return "", 0, ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
// Delete 记录被删除的 API Key ID 并返回预设的错误。
|
||||
// 通过 deletedIDs 可以验证删除操作是否被正确调用。
|
||||
func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
// - invalidated: 记录被清除缓存的用户 ID 列表
|
||||
type apiKeyCacheStub struct {
|
||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
|
||||
}
|
||||
|
||||
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
||||
func (s *apiKeyCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// IncrementCreateAttemptCount 空实现,本测试不验证此行为
|
||||
func (s *apiKeyCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteCreateAttemptCount 记录被清除缓存的用户 ID。
|
||||
// 删除 API Key 时会调用此方法清除用户的创建尝试计数缓存。
|
||||
func (s *apiKeyCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
s.invalidated = append(s.invalidated, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementDailyUsage 空实现,本测试不验证此行为
|
||||
func (s *apiKeyCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDailyUsageExpiry 空实现,本测试不验证此行为
|
||||
func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||
// - 调用者 userID 为 2(不匹配)
|
||||
// - 返回 ErrInsufficientPerms 错误
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||
require.Empty(t, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 7
|
||||
// - 调用者 userID 为 7(匹配)
|
||||
// - Delete 成功执行
|
||||
// - 缓存被正确清除(使用 ownerID)
|
||||
// - 返回 nil 错误
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 99, 1)
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
require.Empty(t, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回正确的所有者 ID
|
||||
// - 所有权验证通过
|
||||
// - 缓存被清除(在删除之前)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete api key" 的错误信息
|
||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "delete api key")
|
||||
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
||||
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
}
|
||||
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
33
backend/internal/service/auth_cache_invalidation_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &UsageService{authCacheInvalidator: invalidator}
|
||||
|
||||
svc.invalidateUsageCaches(context.Background(), 7, false)
|
||||
require.Empty(t, invalidator.userIDs)
|
||||
|
||||
svc.invalidateUsageCaches(context.Background(), 7, true)
|
||||
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||
}
|
||||
|
||||
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &RedeemService{authCacheInvalidator: invalidator}
|
||||
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
|
||||
groupID := int64(3)
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
|
||||
|
||||
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
|
||||
}
|
||||
582
backend/internal/service/auth_service.go
Normal file
582
backend/internal/service/auth_service.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password")
|
||||
ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active")
|
||||
ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists")
|
||||
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
|
||||
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
|
||||
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
|
||||
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
|
||||
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
|
||||
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
|
||||
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
|
||||
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
|
||||
)
|
||||
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
const maxTokenLength = 8192
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
TokenVersion int64 `json:"token_version"` // Used to invalidate tokens on password change
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo UserRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
userRepo UserRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
// 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。
|
||||
if isReservedEmail(email) {
|
||||
return "", nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查是否需要邮件验证
|
||||
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// 如果邮件验证已开启但邮件服务未配置,拒绝注册
|
||||
// 这是一个配置错误,不应该允许绕过验证
|
||||
if s.emailService == nil {
|
||||
log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration")
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if verifyCode == "" {
|
||||
return "", nil, ErrEmailVerifyRequired
|
||||
}
|
||||
// 验证邮箱验证码
|
||||
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
|
||||
return "", nil, fmt.Errorf("verify code: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
|
||||
// 密码哈希
|
||||
hashedPassword, err := s.HashPassword(password)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 获取默认配置
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &User{
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
// 优先检查邮箱冲突错误(竞态条件下可能发生)
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
return "", nil, ErrEmailExists
|
||||
}
|
||||
log.Printf("[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
user = updatedUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// SendVerifyCodeResult 发送验证码返回结果
|
||||
type SendVerifyCodeResult struct {
|
||||
Countdown int `json:"countdown"` // 倒计时秒数
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送邮箱验证码(同步方式)
|
||||
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
return ErrEmailExists
|
||||
}
|
||||
|
||||
// 发送验证码
|
||||
if s.emailService == nil {
|
||||
return errors.New("email service not configured")
|
||||
}
|
||||
|
||||
// 获取网站名称
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
return s.emailService.SendVerifyCode(ctx, email, siteName)
|
||||
}
|
||||
|
||||
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
|
||||
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
|
||||
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
|
||||
|
||||
// 检查是否开放注册(默认关闭)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
log.Println("[Auth] Registration is disabled")
|
||||
return nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
if isReservedEmail(email) {
|
||||
return nil, ErrEmailReserved
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error checking email exists: %v", err)
|
||||
return nil, ErrServiceUnavailable
|
||||
}
|
||||
if existsEmail {
|
||||
log.Printf("[Auth] Email already exists: %s", email)
|
||||
return nil, ErrEmailExists
|
||||
}
|
||||
|
||||
// 检查邮件队列服务是否配置
|
||||
if s.emailQueueService == nil {
|
||||
log.Println("[Auth] Email queue service not configured")
|
||||
return nil, errors.New("email queue service not configured")
|
||||
}
|
||||
|
||||
// 获取网站名称
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
// 异步发送
|
||||
log.Printf("[Auth] Enqueueing verify code for: %s", email)
|
||||
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue: %v", err)
|
||||
return nil, fmt.Errorf("enqueue verify code: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Verify code enqueued successfully for: %s", email)
|
||||
return &SendVerifyCodeResult{
|
||||
Countdown: 60, // 60秒倒计时
|
||||
}, nil
|
||||
}
|
||||
|
||||
// VerifyTurnstile 验证Turnstile token
|
||||
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
|
||||
required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required
|
||||
|
||||
if required {
|
||||
if s.settingService == nil {
|
||||
log.Println("[Auth] Turnstile required but settings service is not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
enabled := s.settingService.IsTurnstileEnabled(ctx)
|
||||
secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != ""
|
||||
if !enabled || !secretConfigured {
|
||||
log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured)
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
}
|
||||
|
||||
if s.turnstileService == nil {
|
||||
if required {
|
||||
log.Println("[Auth] Turnstile required but service not configured")
|
||||
return ErrTurnstileNotConfigured
|
||||
}
|
||||
return nil // 服务未配置则跳过验证
|
||||
}
|
||||
|
||||
if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" {
|
||||
log.Println("[Auth] Turnstile enabled but secret key not configured")
|
||||
}
|
||||
|
||||
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
|
||||
}
|
||||
|
||||
// IsTurnstileEnabled 检查是否启用Turnstile验证
|
||||
func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
if s.turnstileService == nil {
|
||||
return false
|
||||
}
|
||||
return s.turnstileService.IsEnabled(ctx)
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return false // 安全默认:settingService 未配置时关闭注册
|
||||
}
|
||||
return s.settingService.IsRegistrationEnabled(ctx)
|
||||
}
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsEmailVerifyEnabled(ctx)
|
||||
}
|
||||
|
||||
// Login 用户登录,返回JWT token
|
||||
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
|
||||
// 查找用户
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
// 记录数据库错误但不暴露给用户
|
||||
log.Printf("[Auth] Database error during login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !s.CheckPassword(password, user.PasswordHash) {
|
||||
return "", nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录:
|
||||
// - 如果邮箱已存在:直接登录(不需要本地密码)
|
||||
// - 如果邮箱不存在:创建新用户并登录
|
||||
//
|
||||
// 注意:该函数用于 LinuxDo OAuth 登录场景(不同于上游账号的 OAuth,例如 Claude/OpenAI/Gemini)。
|
||||
// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。
|
||||
func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
if email == "" || len(email) > 255 {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
|
||||
}
|
||||
|
||||
username = strings.TrimSpace(username)
|
||||
if len([]rune(username)) > 100 {
|
||||
username = string([]rune(username)[:100])
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// OAuth 首次登录视为注册(fail-close:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
}
|
||||
|
||||
randomPassword, err := randomHexString(32)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
hashedPassword, err := s.HashPassword(randomPassword)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// 新用户默认值。
|
||||
defaultBalance := s.cfg.Default.UserBalance
|
||||
defaultConcurrency := s.cfg.Default.UserConcurrency
|
||||
if s.settingService != nil {
|
||||
defaultBalance = s.settingService.GetDefaultBalance(ctx)
|
||||
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
|
||||
}
|
||||
|
||||
newUser := &User{
|
||||
Email: email,
|
||||
Username: username,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, newUser); err != nil {
|
||||
if errors.Is(err, ErrEmailExists) {
|
||||
// 并发场景:GetByEmail 与 Create 之间用户被创建。
|
||||
user, err = s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Database error getting user after conflict: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error creating oauth user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
}
|
||||
} else {
|
||||
log.Printf("[Auth] Database error during oauth login: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
}
|
||||
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
// 尽力补全:当用户名为空时,使用第三方返回的用户名回填。
|
||||
if user.Username == "" && username != "" {
|
||||
user.Username = username
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
return token, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token并返回用户声明
|
||||
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
|
||||
if len(tokenString) > maxTokenLength {
|
||||
return nil, ErrTokenTooLarge
|
||||
}
|
||||
|
||||
// 使用解析器并限制可接受的签名算法,防止算法混淆。
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{
|
||||
jwt.SigningMethodHS256.Name,
|
||||
jwt.SigningMethodHS384.Name,
|
||||
jwt.SigningMethodHS512.Name,
|
||||
}))
|
||||
|
||||
// 保留默认 claims 校验(exp/nbf),避免放行过期或未生效的 token。
|
||||
token, err := parser.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, jwt.ErrTokenExpired) {
|
||||
// token 过期但仍返回 claims(用于 RefreshToken 等场景)
|
||||
// jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok {
|
||||
return claims, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrTokenExpired
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
func randomHexString(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 16
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
TokenVersion: user.TokenVersion,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sign token: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// HashPassword 使用bcrypt加密密码
|
||||
func (s *AuthService) HashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码是否匹配
|
||||
func (s *AuthService) CheckPassword(password, hashedPassword string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新token
|
||||
func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) {
|
||||
// 验证旧token(即使过期也允许,用于刷新)
|
||||
claims, err := s.ValidateToken(oldTokenString)
|
||||
if err != nil && !errors.Is(err, ErrTokenExpired) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 获取最新的用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, claims.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return "", ErrInvalidToken
|
||||
}
|
||||
log.Printf("[Auth] Database error refreshing token: %v", err)
|
||||
return "", ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return "", ErrUserNotActive
|
||||
}
|
||||
|
||||
// Security: Check TokenVersion to prevent refreshing revoked tokens
|
||||
// This ensures tokens issued before a password change cannot be refreshed
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
return "", ErrTokenRevoked
|
||||
}
|
||||
|
||||
// 生成新token
|
||||
return s.GenerateToken(user)
|
||||
}
|
||||
295
backend/internal/service/auth_service_register_test.go
Normal file
295
backend/internal/service/auth_service_register_test.go
Normal file
@@ -0,0 +1,295 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type settingRepoStub struct {
|
||||
values map[string]string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if s.err != nil {
|
||||
return "", s.err
|
||||
}
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type emailCacheStub struct {
|
||||
data *VerificationCodeData
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.data, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 3.5,
|
||||
UserConcurrency: 2,
|
||||
},
|
||||
}
|
||||
|
||||
var settingService *SettingService
|
||||
if settings != nil {
|
||||
settingService = NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
}
|
||||
|
||||
var emailService *EmailService
|
||||
if emailCache != nil {
|
||||
emailService = NewEmailService(&settingRepoStub{values: settings}, emailCache)
|
||||
}
|
||||
|
||||
return NewAuthService(
|
||||
repo,
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "false",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
|
||||
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{} // 配置 emailService
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{
|
||||
data: &VerificationCodeData{Code: "expected", Attempts: 0},
|
||||
}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{exists: true}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
||||
repo := &userRepoStub{existsErr: errors.New("db down")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
|
||||
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 5}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
token, user, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(5), user.ID)
|
||||
require.Equal(t, "user@test.com", user.Email)
|
||||
require.Equal(t, RoleUser, user.Role)
|
||||
require.Equal(t, StatusActive, user.Status)
|
||||
require.Equal(t, 3.5, user.Balance)
|
||||
require.Equal(t, 2, user.Concurrency)
|
||||
require.Len(t, repo.created, 1)
|
||||
require.True(t, user.CheckPassword("password"))
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建用户并生成 token
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证有效 token
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
|
||||
// 模拟过期 token(通过创建一个过期很久的 token)
|
||||
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1 // 恢复
|
||||
|
||||
// 验证过期 token 应返回 claims 和 ErrTokenExpired
|
||||
claims, err = service.ValidateToken(expiredToken)
|
||||
require.ErrorIs(t, err, ErrTokenExpired)
|
||||
require.NotNil(t, claims, "claims should not be nil when token is expired")
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
require.Equal(t, "test@test.com", claims.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
repo := &userRepoStub{user: user}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建过期 token
|
||||
service.cfg.JWT.ExpireHour = -1
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1
|
||||
|
||||
// RefreshToken 使用过期 token 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
newToken, err := service.RefreshToken(context.Background(), expiredToken)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, newToken)
|
||||
})
|
||||
}
|
||||
15
backend/internal/service/billing_cache_port.go
Normal file
15
backend/internal/service/billing_cache_port.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionCacheData represents cached subscription data
|
||||
type SubscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
661
backend/internal/service/billing_cache_service.go
Normal file
661
backend/internal/service/billing_cache_service.go
Normal file
@@ -0,0 +1,661 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
// 注:ErrInsufficientBalance在redeem_service.go中定义
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
type subscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
|
||||
// 缓存写入任务类型
|
||||
type cacheWriteKind int
|
||||
|
||||
const (
|
||||
cacheWriteSetBalance cacheWriteKind = iota
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
|
||||
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
|
||||
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
|
||||
// 3. goroutine 创建/销毁带来额外开销
|
||||
//
|
||||
// 新实现使用固定大小的工作池:
|
||||
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
|
||||
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
|
||||
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
|
||||
// 4. 统一超时控制,避免慢操作阻塞工作池
|
||||
const (
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
cacheWriteDropClosedCount uint64
|
||||
cacheWriteDropClosedLastLog int64
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
}
|
||||
close(s.cacheWriteChan)
|
||||
s.cacheWriteWg.Wait()
|
||||
s.cacheWriteChan = nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker()
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||||
if s.cacheWriteChan == nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
// 队列已关闭时可能触发 panic,记录后静默失败。
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
enqueued = false
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
return true
|
||||
default:
|
||||
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
|
||||
s.logCacheWriteDrop(task, "full")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker() {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range s.cacheWriteChan {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
s.setBalanceCache(ctx, task.userID, task.balance)
|
||||
case cacheWriteSetSubscription:
|
||||
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||||
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteDeductBalance:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
|
||||
func cacheWriteKindName(kind cacheWriteKind) string {
|
||||
switch kind {
|
||||
case cacheWriteSetBalance:
|
||||
return "set_balance"
|
||||
case cacheWriteSetSubscription:
|
||||
return "set_subscription"
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
return "update_subscription_usage"
|
||||
case cacheWriteDeductBalance:
|
||||
return "deduct_balance"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
|
||||
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
|
||||
var (
|
||||
countPtr *uint64
|
||||
lastPtr *int64
|
||||
)
|
||||
switch reason {
|
||||
case "full":
|
||||
countPtr = &s.cacheWriteDropFullCount
|
||||
lastPtr = &s.cacheWriteDropFullLastLog
|
||||
case "closed":
|
||||
countPtr = &s.cacheWriteDropClosedCount
|
||||
lastPtr = &s.cacheWriteDropClosedLastLog
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddUint64(countPtr, 1)
|
||||
now := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(lastPtr)
|
||||
if now-last < int64(cacheWriteDropLogInterval) {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
|
||||
return
|
||||
}
|
||||
dropped := atomic.SwapUint64(countPtr, 0)
|
||||
if dropped == 0 {
|
||||
return
|
||||
}
|
||||
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||||
reason,
|
||||
dropped,
|
||||
cacheWriteDropLogInterval,
|
||||
cacheWriteKindName(task.kind),
|
||||
task.userID,
|
||||
task.groupID,
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 余额缓存方法
|
||||
// ============================================
|
||||
|
||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
if s.cache == nil {
|
||||
// Redis不可用,直接查询数据库
|
||||
return s.getUserBalanceFromDB(ctx, userID)
|
||||
}
|
||||
|
||||
// 尝试从缓存读取
|
||||
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||||
if err == nil {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
balance, err = s.getUserBalanceFromDB(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// getUserBalanceFromDB 从数据库获取用户余额
|
||||
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get user balance: %w", err)
|
||||
}
|
||||
return user.Balance, nil
|
||||
}
|
||||
|
||||
// setBalanceCache 设置余额缓存
|
||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(同步调用)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// QueueDeductBalance 异步扣减余额缓存
|
||||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,避免关键扣减被静默丢弃。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: userID,
|
||||
amount: amount,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 订阅缓存方法
|
||||
// ============================================
|
||||
|
||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
if s.cache == nil {
|
||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
}
|
||||
|
||||
// 尝试从缓存读取
|
||||
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
if err == nil && cacheData != nil {
|
||||
return s.convertFromPortsData(cacheData), nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetSubscription,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
subscriptionData: data,
|
||||
})
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData {
|
||||
return &subscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData {
|
||||
return &SubscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get subscription: %w", err)
|
||||
}
|
||||
|
||||
return &subscriptionCacheData{
|
||||
Status: sub.Status,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
DailyUsage: sub.DailyUsageUSD,
|
||||
WeeklyUsage: sub.WeeklyUsageUSD,
|
||||
MonthlyUsage: sub.MonthlyUsageUSD,
|
||||
Version: sub.UpdatedAt.Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// setSubscriptionCache 设置订阅缓存
|
||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||
if s.cache == nil || data == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,确保订阅用量及时更新。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateSubscriptionUsage,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
amount: costUSD,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
|
||||
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 统一检查方法
|
||||
// ============================================
|
||||
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
|
||||
return ErrBillingServiceUnavailable
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
if isSubscriptionMode {
|
||||
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
|
||||
}
|
||||
|
||||
return s.checkBalanceEligibility(ctx, user.ID)
|
||||
}
|
||||
|
||||
// checkBalanceEligibility 检查余额模式资格
|
||||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||||
balance, err := s.GetUserBalance(ctx, userID)
|
||||
if err != nil {
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
if balance <= 0 {
|
||||
return ErrInsufficientBalance
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSubscriptionEligibility 检查订阅模式资格
|
||||
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
if subData.Status != SubscriptionStatusActive {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(subData.ExpiresAt) {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
// 检查限额(使用传入的Group限额配置)
|
||||
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
|
||||
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
|
||||
return ErrWeeklyLimitExceeded
|
||||
}
|
||||
|
||||
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
|
||||
return ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type billingCircuitBreakerState int
|
||||
|
||||
const (
|
||||
billingCircuitClosed billingCircuitBreakerState = iota
|
||||
billingCircuitOpen
|
||||
billingCircuitHalfOpen
|
||||
)
|
||||
|
||||
type billingCircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state billingCircuitBreakerState
|
||||
failures int
|
||||
openedAt time.Time
|
||||
failureThreshold int
|
||||
resetTimeout time.Duration
|
||||
halfOpenRequests int
|
||||
halfOpenRemaining int
|
||||
}
|
||||
|
||||
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
|
||||
if resetTimeout <= 0 {
|
||||
resetTimeout = 30 * time.Second
|
||||
}
|
||||
halfOpen := cfg.HalfOpenRequests
|
||||
if halfOpen <= 0 {
|
||||
halfOpen = 1
|
||||
}
|
||||
threshold := cfg.FailureThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 5
|
||||
}
|
||||
return &billingCircuitBreaker{
|
||||
state: billingCircuitClosed,
|
||||
failureThreshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
halfOpenRequests: halfOpen,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) Allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitClosed:
|
||||
return true
|
||||
case billingCircuitOpen:
|
||||
if time.Since(b.openedAt) < b.resetTimeout {
|
||||
return false
|
||||
}
|
||||
b.state = billingCircuitHalfOpen
|
||||
b.halfOpenRemaining = b.halfOpenRequests
|
||||
log.Printf("ALERT: billing circuit breaker entering half-open state")
|
||||
fallthrough
|
||||
case billingCircuitHalfOpen:
|
||||
if b.halfOpenRemaining <= 0 {
|
||||
return false
|
||||
}
|
||||
b.halfOpenRemaining--
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitOpen:
|
||||
return
|
||||
case billingCircuitHalfOpen:
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
return
|
||||
default:
|
||||
b.failures++
|
||||
if b.failures >= b.failureThreshold {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnSuccess() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
previousState := b.state
|
||||
previousFailures := b.failures
|
||||
|
||||
b.state = billingCircuitClosed
|
||||
b.failures = 0
|
||||
b.halfOpenRemaining = 0
|
||||
|
||||
// 只有状态真正发生变化时才记录日志
|
||||
if previousState != billingCircuitClosed {
|
||||
log.Printf("ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
} else if previousFailures > 0 {
|
||||
log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
}
|
||||
}
|
||||
|
||||
func circuitStateString(state billingCircuitBreakerState) string {
|
||||
switch state {
|
||||
case billingCircuitClosed:
|
||||
return "closed"
|
||||
case billingCircuitOpen:
|
||||
return "open"
|
||||
case billingCircuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
75
backend/internal/service/billing_cache_service_test.go
Normal file
75
backend/internal/service/billing_cache_service_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheWorkerStub struct {
|
||||
balanceUpdates int64
|
||||
subscriptionUpdates int64
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < cacheWriteBufferSize*2; i++ {
|
||||
svc.QueueDeductBalance(1, 1)
|
||||
}
|
||||
require.Less(t, time.Since(start), 2*time.Second)
|
||||
|
||||
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.balanceUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
382
backend/internal/service/billing_service.go
Normal file
382
backend/internal/service/billing_service.go
Normal file
@@ -0,0 +1,382 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||
|
||||
// Subscription operations
|
||||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
}
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建价格(每百万token)- 仅用于硬编码回退
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
}
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
}
|
||||
|
||||
// CostBreakdown 费用明细
|
||||
type CostBreakdown struct {
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
CacheCreationCost float64
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
ActualCost float64 // 应用倍率后的实际费用
|
||||
}
|
||||
|
||||
// BillingService 计费服务
|
||||
type BillingService struct {
|
||||
cfg *config.Config
|
||||
pricingService *PricingService
|
||||
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
|
||||
}
|
||||
|
||||
// NewBillingService 创建计费服务实例
|
||||
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
|
||||
s := &BillingService{
|
||||
cfg: cfg,
|
||||
pricingService: pricingService,
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
|
||||
// 初始化硬编码回退价格(当动态价格不可用时使用)
|
||||
s.initFallbackPricing()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
|
||||
// 价格单位:USD per token(与LiteLLM格式一致)
|
||||
func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.5 Opus
|
||||
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
|
||||
InputPricePerToken: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 25e-6, // $25 per MTok
|
||||
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
|
||||
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 4 Sonnet
|
||||
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3.5 Sonnet
|
||||
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3.5 Haiku
|
||||
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
|
||||
InputPricePerToken: 1e-6, // $1 per MTok
|
||||
OutputPricePerToken: 5e-6, // $5 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3 Opus
|
||||
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
|
||||
InputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerToken: 75e-6, // $75 per MTok
|
||||
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
|
||||
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3 Haiku
|
||||
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
|
||||
InputPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// 按模型系列匹配
|
||||
if strings.Contains(modelLower, "opus") {
|
||||
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
||||
return s.fallbackPrices["claude-opus-4.5"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
}
|
||||
|
||||
// 默认使用Sonnet价格
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格配置
|
||||
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
// 标准化模型名称(转小写)
|
||||
model = strings.ToLower(model)
|
||||
|
||||
// 1. 优先从动态价格服务获取
|
||||
if s.pricingService != nil {
|
||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||
if litellmPricing != nil {
|
||||
return &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
SupportsCacheBreakdown: false,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 使用硬编码回退价格
|
||||
fallback := s.getFallbackPricing(model)
|
||||
if fallback != nil {
|
||||
log.Printf("[Billing] Using fallback pricing for model: %s", model)
|
||||
return fallback, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
|
||||
// 应用倍率计算实际费用
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
|
||||
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if multiplier <= 0 {
|
||||
multiplier = 1.0
|
||||
}
|
||||
return s.CalculateCost(model, tokens, multiplier)
|
||||
}
|
||||
|
||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||
func (s *BillingService) ListSupportedModels() []string {
|
||||
models := make([]string, 0)
|
||||
// 返回回退价格支持的模型系列
|
||||
for model := range s.fallbackPrices {
|
||||
models = append(models, model)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否支持(现在总是返回true,因为有模糊匹配回退)
|
||||
func (s *BillingService) IsModelSupported(model string) bool {
|
||||
// 所有Claude模型都有回退价格支持
|
||||
modelLower := strings.ToLower(model)
|
||||
return strings.Contains(modelLower, "claude") ||
|
||||
strings.Contains(modelLower, "opus") ||
|
||||
strings.Contains(modelLower, "sonnet") ||
|
||||
strings.Contains(modelLower, "haiku")
|
||||
}
|
||||
|
||||
// GetEstimatedCost 估算费用(用于前端展示)
|
||||
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
|
||||
tokens := UsageTokens{
|
||||
InputTokens: estimatedInputTokens,
|
||||
OutputTokens: estimatedOutputTokens,
|
||||
}
|
||||
|
||||
breakdown, err := s.CalculateCostWithConfig(model, tokens)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return breakdown.ActualCost, nil
|
||||
}
|
||||
|
||||
// GetPricingServiceStatus 获取价格服务状态
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]any {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.GetStatus()
|
||||
}
|
||||
return map[string]any{
|
||||
"model_count": len(s.fallbackPrices),
|
||||
"last_updated": "using fallback",
|
||||
"local_hash": "N/A",
|
||||
}
|
||||
}
|
||||
|
||||
// ForceUpdatePricing 强制更新价格数据
|
||||
func (s *BillingService) ForceUpdatePricing() error {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.ForceUpdate()
|
||||
}
|
||||
return fmt.Errorf("pricing service not initialized")
|
||||
}
|
||||
|
||||
// ImagePriceConfig 图片计费配置
|
||||
type ImagePriceConfig struct {
|
||||
Price1K *float64 // 1K 尺寸价格(nil 表示使用默认值)
|
||||
Price2K *float64 // 2K 尺寸价格(nil 表示使用默认值)
|
||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||
}
|
||||
|
||||
// CalculateImageCost 计算图片生成费用
|
||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||
// imageCount: 生成的图片数量
|
||||
// groupConfig: 分组配置的价格(可能为 nil,表示使用默认值)
|
||||
// rateMultiplier: 费率倍数
|
||||
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
// 获取单价
|
||||
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
|
||||
|
||||
// 计算总费用
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
|
||||
// 应用倍率
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// getImageUnitPrice 获取图片单价
|
||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||
// 优先使用分组配置的价格
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
if groupConfig.Price1K != nil {
|
||||
return *groupConfig.Price1K
|
||||
}
|
||||
case "2K":
|
||||
if groupConfig.Price2K != nil {
|
||||
return *groupConfig.Price2K
|
||||
}
|
||||
case "4K":
|
||||
if groupConfig.Price4K != nil {
|
||||
return *groupConfig.Price4K
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 LiteLLM 默认价格
|
||||
return s.getDefaultImagePrice(model, imageSize)
|
||||
}
|
||||
|
||||
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
|
||||
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
|
||||
basePrice := 0.0
|
||||
|
||||
// 从 PricingService 获取 output_cost_per_image
|
||||
if s.pricingService != nil {
|
||||
pricing := s.pricingService.GetModelPricing(model)
|
||||
if pricing != nil && pricing.OutputCostPerImage > 0 {
|
||||
basePrice = pricing.OutputCostPerImage
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview)
|
||||
if basePrice <= 0 {
|
||||
basePrice = 0.134
|
||||
}
|
||||
|
||||
// 4K 尺寸翻倍
|
||||
if imageSize == "4K" {
|
||||
return basePrice * 2
|
||||
}
|
||||
|
||||
return basePrice
|
||||
}
|
||||
149
backend/internal/service/billing_service_image_test.go
Normal file
149
backend/internal/service/billing_service_image_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCalculateImageCost_DefaultPricing 测试无分组配置时使用默认价格
|
||||
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值
|
||||
|
||||
// 2K 尺寸,默认价格 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001)
|
||||
|
||||
// 多张图片
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
|
||||
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
|
||||
func TestCalculateImageCost_GroupCustomPricing(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
Price2K: &price2K,
|
||||
Price4K: &price4K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 2, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.15, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.30, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍
|
||||
func TestCalculateImageCost_4KDoublePrice(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 4K 尺寸,默认价格翻倍 $0.134 * 2 = $0.268
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_RateMultiplier 测试费率倍数
|
||||
func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.536, cost.ActualCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
|
||||
func TestCalculateImageCost_ZeroCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 0, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_NegativeCount 测试 imageCount=-1
|
||||
func TestCalculateImageCost_NegativeCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", -1, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
|
||||
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
func TestGetImageUnitPrice_GroupPriorityOverDefault(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price2K := 0.20
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price2K: &price2K,
|
||||
}
|
||||
|
||||
// 分组配置了 2K 价格,应该使用分组价格而不是默认的 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_PartialGroupConfig 测试分组部分配置时回退默认
|
||||
func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 只配置 1K 价格
|
||||
price1K := 0.10
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 回退默认价格 $0.134
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 回退默认价格 $0.268 (翻倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetDefaultImagePrice_FallbackHardcoded 测试 PricingService 无数据时使用硬编码默认值
|
||||
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil
|
||||
|
||||
// 1K 和 2K 使用相同的默认价格 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
}
|
||||
265
backend/internal/service/claude_code_validator.go
Normal file
265
backend/internal/service/claude_code_validator.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
|
||||
// 完全学习自 claude-relay-service 项目的验证逻辑
|
||||
type ClaudeCodeValidator struct{}
|
||||
|
||||
var (
|
||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
||||
|
||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||
systemPromptThreshold = 0.5
|
||||
)
|
||||
|
||||
// Claude Code 官方 System Prompt 模板
|
||||
// 从 claude-relay-service/src/utils/contents.js 提取
|
||||
var claudeCodeSystemPrompts = []string{
|
||||
// claudeOtherSystemPrompt1 - Primary
|
||||
"You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPrompt3 - Agent SDK
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
|
||||
|
||||
// claudeOtherSystemPrompt4 - Compact Agent SDK
|
||||
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
|
||||
|
||||
// exploreAgentSystemPrompt
|
||||
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
|
||||
"You are a helpful AI assistant tasked with summarizing conversations.",
|
||||
|
||||
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
|
||||
"You are an interactive CLI tool that helps users",
|
||||
}
|
||||
|
||||
// NewClaudeCodeValidator 创建验证器实例
|
||||
func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||
return &ClaudeCodeValidator{}
|
||||
}
|
||||
|
||||
// Validate 验证请求是否来自 Claude Code CLI
|
||||
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||
//
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||
// Step 3: 对于 messages 路径,进行严格验证:
|
||||
// - System prompt 相似度检查
|
||||
// - X-App header 检查
|
||||
// - anthropic-beta header 检查
|
||||
// - anthropic-version header 检查
|
||||
// - metadata.user_id 格式验证
|
||||
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
|
||||
// Step 1: User-Agent 检查
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !claudeCodeUAPattern.MatchString(ua) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Step 2: 非 messages 路径,只要 UA 匹配就通过
|
||||
path := r.URL.Path
|
||||
if !strings.Contains(path, "messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Step 3: messages 路径,进行严格验证
|
||||
|
||||
// 3.1 检查 system prompt 相似度
|
||||
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.2 检查必需的 headers(值不为空即可)
|
||||
xApp := r.Header.Get("X-App")
|
||||
if xApp == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicBeta := r.Header.Get("anthropic-beta")
|
||||
if anthropicBeta == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicVersion := r.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3.3 验证 metadata.user_id
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
metadata, ok := body["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if !userIDPattern.MatchString(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
|
||||
// 使用字符串相似度匹配(Dice coefficient)
|
||||
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 model 字段
|
||||
if _, ok := body["model"].(string); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取 system 字段
|
||||
systemEntries, ok := body["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每个 system entry
|
||||
for _, entry := range systemEntries {
|
||||
entryMap, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok := entryMap["text"].(string)
|
||||
if !ok || text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算与所有模板的最佳相似度
|
||||
bestScore := v.bestSimilarityScore(text)
|
||||
if bestScore >= systemPromptThreshold {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
|
||||
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
|
||||
normalizedText := normalizePrompt(text)
|
||||
bestScore := 0.0
|
||||
|
||||
for _, template := range claudeCodeSystemPrompts {
|
||||
normalizedTemplate := normalizePrompt(template)
|
||||
score := diceCoefficient(normalizedText, normalizedTemplate)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
return bestScore
|
||||
}
|
||||
|
||||
// normalizePrompt 标准化提示词文本(去除多余空白)
|
||||
func normalizePrompt(text string) string {
|
||||
// 将所有空白字符替换为单个空格,并去除首尾空白
|
||||
return strings.Join(strings.Fields(text), " ")
|
||||
}
|
||||
|
||||
// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient)
|
||||
// 这是 string-similarity 库使用的算法
|
||||
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
|
||||
func diceCoefficient(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
if len(a) < 2 || len(b) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 生成 bigrams
|
||||
bigramsA := getBigrams(a)
|
||||
bigramsB := getBigrams(b)
|
||||
|
||||
if len(bigramsA) == 0 || len(bigramsB) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算交集大小
|
||||
intersection := 0
|
||||
for bigram, countA := range bigramsA {
|
||||
if countB, exists := bigramsB[bigram]; exists {
|
||||
if countA < countB {
|
||||
intersection += countA
|
||||
} else {
|
||||
intersection += countB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算总 bigram 数量
|
||||
totalA := 0
|
||||
for _, count := range bigramsA {
|
||||
totalA += count
|
||||
}
|
||||
totalB := 0
|
||||
for _, count := range bigramsB {
|
||||
totalB += count
|
||||
}
|
||||
|
||||
return float64(2*intersection) / float64(totalA+totalB)
|
||||
}
|
||||
|
||||
// getBigrams 获取字符串的所有 bigrams(相邻字符对)
|
||||
func getBigrams(s string) map[string]int {
|
||||
bigrams := make(map[string]int)
|
||||
runes := []rune(strings.ToLower(s))
|
||||
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
bigram := string(runes[i : i+2])
|
||||
bigrams[bigram]++
|
||||
}
|
||||
|
||||
return bigrams
|
||||
}
|
||||
|
||||
// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景)
|
||||
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
|
||||
return claudeCodeUAPattern.MatchString(ua)
|
||||
}
|
||||
|
||||
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
|
||||
// 只要存在匹配的系统提示词就返回 true(用于宽松检测)
|
||||
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
return v.hasClaudeCodeSystemPrompt(body)
|
||||
}
|
||||
|
||||
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
|
||||
func IsClaudeCodeClient(ctx context.Context) bool {
|
||||
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
|
||||
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
|
||||
}
|
||||
314
backend/internal/service/concurrency_service.go
Normal file
314
backend/internal/service/concurrency_service.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||||
type ConcurrencyCache interface {
|
||||
// 账号槽位管理
|
||||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// 账号等待队列(账号级)
|
||||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||||
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// 用户槽位管理
|
||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
|
||||
// 批量负载查询(只读)
|
||||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||
|
||||
// 清理过期槽位(后台任务)
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
// Uses 8 random bytes (16 hex chars) for uniqueness
|
||||
func generateRequestID() string {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to nanosecond timestamp (extremely rare case)
|
||||
return fmt.Sprintf("%x", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
cache ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
type AcquireResult struct {
|
||||
Acquired bool
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
type AccountWithConcurrency struct {
|
||||
ID int64
|
||||
MaxConcurrency int
|
||||
}
|
||||
|
||||
type AccountLoadInfo struct {
|
||||
AccountID int64
|
||||
CurrentConcurrency int
|
||||
WaitingCount int
|
||||
LoadRate int // 0-100+ (percent)
|
||||
}
|
||||
|
||||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate unique request ID for this slot
|
||||
requestID := generateRequestID()
|
||||
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
|
||||
log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate unique request ID for this slot
|
||||
requestID := generateRequestID()
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
|
||||
log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// Wait Queue Count Methods
|
||||
// ============================================
|
||||
|
||||
// IncrementWaitCount attempts to increment the wait queue counter for a user.
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWaitingCount gets current wait queue count for an account.
|
||||
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
if userConcurrency <= 0 {
|
||||
userConcurrency = 1
|
||||
}
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||
}
|
||||
|
||||
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||||
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||||
}
|
||||
|
||||
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||||
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||||
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
runCleanup := func() {
|
||||
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: list schedulable accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||||
accountCancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runCleanup()
|
||||
for range ticker.C {
|
||||
runCleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int)
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
count, err := s.cache.GetAccountConcurrency(ctx, accountID)
|
||||
if err != nil {
|
||||
// If key doesn't exist in Redis, count is 0
|
||||
count = 0
|
||||
}
|
||||
result[accountID] = count
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
1255
backend/internal/service/crs_sync_service.go
Normal file
1255
backend/internal/service/crs_sync_service.go
Normal file
File diff suppressed because it is too large
Load Diff
258
backend/internal/service/dashboard_aggregation_service.go
Normal file
258
backend/internal/service/dashboard_aggregation_service.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDashboardAggregationTimeout = 2 * time.Minute
|
||||
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
|
||||
dashboardAggregationRetentionInterval = 6 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
)
|
||||
|
||||
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||
type DashboardAggregationRepository interface {
|
||||
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
// DashboardAggregationService 负责定时聚合与回填。
|
||||
type DashboardAggregationService struct {
|
||||
repo DashboardAggregationRepository
|
||||
timingWheel *TimingWheelService
|
||||
cfg config.DashboardAggregationConfig
|
||||
running int32
|
||||
lastRetentionCleanup atomic.Value // time.Time
|
||||
}
|
||||
|
||||
// NewDashboardAggregationService 创建聚合服务。
|
||||
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
var aggCfg config.DashboardAggregationConfig
|
||||
if cfg != nil {
|
||||
aggCfg = cfg.DashboardAgg
|
||||
}
|
||||
return &DashboardAggregationService{
|
||||
repo: repo,
|
||||
timingWheel: timingWheel,
|
||||
cfg: aggCfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时聚合作业(重启生效配置)。
|
||||
func (s *DashboardAggregationService) Start() {
|
||||
if s == nil || s.repo == nil || s.timingWheel == nil {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
log.Printf("[DashboardAggregation] 聚合作业已禁用")
|
||||
return
|
||||
}
|
||||
|
||||
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
|
||||
if interval <= 0 {
|
||||
interval = time.Minute
|
||||
}
|
||||
|
||||
if s.cfg.RecomputeDays > 0 {
|
||||
go s.recomputeRecentDays()
|
||||
}
|
||||
|
||||
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
|
||||
s.runScheduledAggregation()
|
||||
})
|
||||
log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||
if !s.cfg.BackfillEnabled {
|
||||
log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerBackfill 触发回填(异步)。
|
||||
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.BackfillEnabled {
|
||||
log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||
return ErrDashboardBackfillDisabled
|
||||
}
|
||||
if !end.After(start) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
if s.cfg.BackfillMaxDays > 0 {
|
||||
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
|
||||
if end.Sub(start) > maxRange {
|
||||
return ErrDashboardBackfillTooLarge
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, end); err != nil {
|
||||
log.Printf("[DashboardAggregation] 回填失败: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
days := s.cfg.RecomputeDays
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
start := now.AddDate(0, 0, -days)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 启动重算失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[DashboardAggregation] 读取水位失败: %v", err)
|
||||
last = time.Unix(0, 0).UTC()
|
||||
}
|
||||
|
||||
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
start := last.Add(-lookback)
|
||||
if !last.After(epoch) {
|
||||
retentionDays := s.cfg.Retention.UsageLogsDays
|
||||
if retentionDays <= 0 {
|
||||
retentionDays = 1
|
||||
}
|
||||
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
|
||||
} else if start.After(now) {
|
||||
start = now.Add(-lookback)
|
||||
}
|
||||
|
||||
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
|
||||
if updateErr != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
start.Format(time.RFC3339),
|
||||
now.Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
updateErr == nil,
|
||||
)
|
||||
|
||||
s.maybeCleanupRetention(ctx, now)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errors.New("聚合作业正在运行")
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
|
||||
cursor := truncateToDayUTC(startUTC)
|
||||
for cursor.Before(endUTC) {
|
||||
windowEnd := cursor.Add(24 * time.Hour)
|
||||
if windowEnd.After(endUTC) {
|
||||
windowEnd = endUTC
|
||||
}
|
||||
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
cursor = windowEnd
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
|
||||
if updateErr != nil {
|
||||
log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
startUTC.Format(time.RFC3339),
|
||||
endUTC.Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
updateErr == nil,
|
||||
)
|
||||
|
||||
s.maybeCleanupRetention(ctx, endUTC)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if !end.After(start) {
|
||||
return nil
|
||||
}
|
||||
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
|
||||
log.Printf("[DashboardAggregation] 分区检查失败: %v", err)
|
||||
}
|
||||
return s.repo.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
|
||||
lastAny := s.lastRetentionCleanup.Load()
|
||||
if lastAny != nil {
|
||||
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||
}
|
||||
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
|
||||
if usageErr != nil {
|
||||
log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
func truncateToDayUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
106
backend/internal/service/dashboard_aggregation_service_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardAggregationRepoTestStub struct {
|
||||
aggregateCalls int
|
||||
lastStart time.Time
|
||||
lastEnd time.Time
|
||||
watermark time.Time
|
||||
aggregateErr error
|
||||
cleanupAggregatesErr error
|
||||
cleanupUsageErr error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
s.aggregateCalls++
|
||||
s.lastStart = start
|
||||
s.lastEnd = end
|
||||
return s.aggregateErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return s.cleanupAggregatesErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return s.cleanupUsageErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
require.False(t, repo.lastEnd.IsZero())
|
||||
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
BackfillEnabled: true,
|
||||
BackfillMaxDays: 1,
|
||||
},
|
||||
}
|
||||
|
||||
start := time.Now().AddDate(0, 0, -3)
|
||||
end := time.Now()
|
||||
err := svc.TriggerBackfill(start, end)
|
||||
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
|
||||
require.Equal(t, 0, repo.aggregateCalls)
|
||||
}
|
||||
336
backend/internal/service/dashboard_service.go
Normal file
336
backend/internal/service/dashboard_service.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDashboardStatsFreshTTL = 15 * time.Second
|
||||
defaultDashboardStatsCacheTTL = 30 * time.Second
|
||||
defaultDashboardStatsRefreshTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
|
||||
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
|
||||
|
||||
// DashboardStatsCache 定义仪表盘统计缓存接口。
|
||||
type DashboardStatsCache interface {
|
||||
GetDashboardStats(ctx context.Context) (string, error)
|
||||
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
|
||||
DeleteDashboardStats(ctx context.Context) error
|
||||
}
|
||||
|
||||
type dashboardStatsRangeFetcher interface {
|
||||
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
|
||||
}
|
||||
|
||||
type dashboardStatsCacheEntry struct {
|
||||
Stats *usagestats.DashboardStats `json:"stats"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// DashboardService 提供管理员仪表盘统计服务。
|
||||
type DashboardService struct {
|
||||
usageRepo UsageLogRepository
|
||||
aggRepo DashboardAggregationRepository
|
||||
cache DashboardStatsCache
|
||||
cacheFreshTTL time.Duration
|
||||
cacheTTL time.Duration
|
||||
refreshTimeout time.Duration
|
||||
refreshing int32
|
||||
aggEnabled bool
|
||||
aggInterval time.Duration
|
||||
aggLookback time.Duration
|
||||
aggUsageDays int
|
||||
}
|
||||
|
||||
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
|
||||
freshTTL := defaultDashboardStatsFreshTTL
|
||||
cacheTTL := defaultDashboardStatsCacheTTL
|
||||
refreshTimeout := defaultDashboardStatsRefreshTimeout
|
||||
aggEnabled := true
|
||||
aggInterval := time.Minute
|
||||
aggLookback := 2 * time.Minute
|
||||
aggUsageDays := 90
|
||||
if cfg != nil {
|
||||
if !cfg.Dashboard.Enabled {
|
||||
cache = nil
|
||||
}
|
||||
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
|
||||
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Dashboard.StatsTTLSeconds > 0 {
|
||||
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
|
||||
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
|
||||
}
|
||||
aggEnabled = cfg.DashboardAgg.Enabled
|
||||
if cfg.DashboardAgg.IntervalSeconds > 0 {
|
||||
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
|
||||
}
|
||||
if cfg.DashboardAgg.LookbackSeconds > 0 {
|
||||
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
|
||||
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
|
||||
}
|
||||
}
|
||||
if aggRepo == nil {
|
||||
aggEnabled = false
|
||||
}
|
||||
return &DashboardService{
|
||||
usageRepo: usageRepo,
|
||||
aggRepo: aggRepo,
|
||||
cache: cache,
|
||||
cacheFreshTTL: freshTTL,
|
||||
cacheTTL: cacheTTL,
|
||||
refreshTimeout: refreshTimeout,
|
||||
aggEnabled: aggEnabled,
|
||||
aggInterval: aggInterval,
|
||||
aggLookback: aggLookback,
|
||||
aggUsageDays: aggUsageDays,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
if s.cache != nil {
|
||||
cached, fresh, err := s.getCachedDashboardStats(ctx)
|
||||
if err == nil && cached != nil {
|
||||
s.refreshAggregationStaleness(cached)
|
||||
if !fresh {
|
||||
s.refreshDashboardStatsAsync()
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
|
||||
log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := s.refreshDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||
data, err := s.cache.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
var entry dashboardStatsCacheEntry
|
||||
if err := json.Unmarshal([]byte(data), &entry); err != nil {
|
||||
s.evictDashboardStatsCache(err)
|
||||
return nil, false, ErrDashboardStatsCacheMiss
|
||||
}
|
||||
if entry.Stats == nil {
|
||||
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
|
||||
return nil, false, ErrDashboardStatsCacheMiss
|
||||
}
|
||||
|
||||
age := time.Since(time.Unix(entry.UpdatedAt, 0))
|
||||
return entry.Stats, age <= s.cacheFreshTTL, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshDashboardStatsAsync() {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer atomic.StoreInt32(&s.refreshing, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
if !s.aggEnabled {
|
||||
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
|
||||
now := time.Now().UTC()
|
||||
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
|
||||
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
|
||||
}
|
||||
}
|
||||
return s.usageRepo.GetDashboardStats(ctx)
|
||||
}
|
||||
|
||||
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||
if s.cache == nil || stats == nil {
|
||||
return
|
||||
}
|
||||
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: stats,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) evictDashboardStatsCache(reason error) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||
}
|
||||
if reason != nil {
|
||||
log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||
}
|
||||
|
||||
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
updatedAt := s.fetchAggregationUpdatedAt(ctx)
|
||||
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
|
||||
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
|
||||
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
|
||||
if s.aggRepo == nil {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[Dashboard] 读取聚合水位失败: %v", err)
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
if updatedAt.IsZero() {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
return updatedAt.UTC()
|
||||
}
|
||||
|
||||
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
|
||||
if !s.aggEnabled {
|
||||
return true
|
||||
}
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
if !updatedAt.After(epoch) {
|
||||
return true
|
||||
}
|
||||
threshold := s.aggInterval + s.aggLookback
|
||||
return now.Sub(updatedAt) > threshold
|
||||
}
|
||||
|
||||
func parseStatsUpdatedAt(raw string) time.Time {
|
||||
if raw == "" {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
parsed, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
return parsed.UTC()
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user usage trend: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
387
backend/internal/service/dashboard_service_test.go
Normal file
387
backend/internal/service/dashboard_service_test.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type usageRepoStub struct {
|
||||
UsageLogRepository
|
||||
stats *usagestats.DashboardStats
|
||||
rangeStats *usagestats.DashboardStats
|
||||
err error
|
||||
rangeErr error
|
||||
calls int32
|
||||
rangeCalls int32
|
||||
rangeStart time.Time
|
||||
rangeEnd time.Time
|
||||
onCall chan struct{}
|
||||
}
|
||||
|
||||
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.onCall != nil {
|
||||
select {
|
||||
case s.onCall <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.stats, nil
|
||||
}
|
||||
|
||||
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
|
||||
atomic.AddInt32(&s.rangeCalls, 1)
|
||||
s.rangeStart = start
|
||||
s.rangeEnd = end
|
||||
if s.rangeErr != nil {
|
||||
return nil, s.rangeErr
|
||||
}
|
||||
if s.rangeStats != nil {
|
||||
return s.rangeStats, nil
|
||||
}
|
||||
return s.stats, nil
|
||||
}
|
||||
|
||||
type dashboardCacheStub struct {
|
||||
get func(ctx context.Context) (string, error)
|
||||
set func(ctx context.Context, data string, ttl time.Duration) error
|
||||
del func(ctx context.Context) error
|
||||
getCalls int32
|
||||
setCalls int32
|
||||
delCalls int32
|
||||
lastSetMu sync.Mutex
|
||||
lastSet string
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
atomic.AddInt32(&c.getCalls, 1)
|
||||
if c.get != nil {
|
||||
return c.get(ctx)
|
||||
}
|
||||
return "", ErrDashboardStatsCacheMiss
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&c.setCalls, 1)
|
||||
c.lastSetMu.Lock()
|
||||
c.lastSet = data
|
||||
c.lastSetMu.Unlock()
|
||||
if c.set != nil {
|
||||
return c.set(ctx, data, ttl)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
|
||||
atomic.AddInt32(&c.delCalls, 1)
|
||||
if c.del != nil {
|
||||
return c.del(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dashboardAggregationRepoStub struct {
|
||||
watermark time.Time
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
if s.err != nil {
|
||||
return time.Time{}, s.err
|
||||
}
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
|
||||
t.Helper()
|
||||
c.lastSetMu.Lock()
|
||||
data := c.lastSet
|
||||
c.lastSetMu.Unlock()
|
||||
|
||||
var entry dashboardStatsCacheEntry
|
||||
err := json.Unmarshal([]byte(data), &entry)
|
||||
require.NoError(t, err)
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheHitFresh(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 10,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: stats,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return string(payload), nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{
|
||||
stats: &usagestats.DashboardStats{TotalUsers: 99},
|
||||
}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 7,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "", ErrDashboardStatsCacheMiss
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
|
||||
entry := cache.readLastEntry(t)
|
||||
require.Equal(t, stats, entry.Stats)
|
||||
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 3,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
|
||||
staleStats := &usagestats.DashboardStats{
|
||||
TotalUsers: 11,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: staleStats,
|
||||
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return string(payload), nil
|
||||
},
|
||||
}
|
||||
refreshCh := make(chan struct{}, 1)
|
||||
repo := &usageRepoStub{
|
||||
stats: &usagestats.DashboardStats{TotalUsers: 22},
|
||||
onCall: refreshCh,
|
||||
}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, staleStats, got)
|
||||
|
||||
select {
|
||||
case <-refreshCh:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("等待异步刷新超时")
|
||||
}
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt32(&cache.setCalls) >= 1
|
||||
}, 1*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "not-json", nil
|
||||
},
|
||||
}
|
||||
stats := &usagestats.DashboardStats{TotalUsers: 9}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "not-json", nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{err: errors.New("db down")}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
_, err := svc.GetDashboardStats(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
|
||||
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
|
||||
require.True(t, got.StatsStale)
|
||||
}
|
||||
|
||||
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
|
||||
aggNow := time.Now().UTC().Truncate(time.Second)
|
||||
stats := &usagestats.DashboardStats{}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
|
||||
require.False(t, got.StatsStale)
|
||||
}
|
||||
|
||||
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
|
||||
expected := &usagestats.DashboardStats{TotalUsers: 42}
|
||||
repo := &usageRepoStub{
|
||||
rangeStats: expected,
|
||||
err: errors.New("should not call aggregated stats"),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: false,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 7,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, nil, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), got.TotalUsers)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
|
||||
require.False(t, repo.rangeEnd.IsZero())
|
||||
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
|
||||
}
|
||||
76
backend/internal/service/deferred_service.go
Normal file
76
backend/internal/service/deferred_service.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DeferredService provides deferred batch update functionality
|
||||
type DeferredService struct {
|
||||
accountRepo AccountRepository
|
||||
timingWheel *TimingWheelService
|
||||
interval time.Duration
|
||||
|
||||
lastUsedUpdates sync.Map
|
||||
}
|
||||
|
||||
// NewDeferredService creates a new DeferredService instance
|
||||
func NewDeferredService(accountRepo AccountRepository, timingWheel *TimingWheelService, interval time.Duration) *DeferredService {
|
||||
return &DeferredService{
|
||||
accountRepo: accountRepo,
|
||||
timingWheel: timingWheel,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the deferred service
|
||||
func (s *DeferredService) Start() {
|
||||
s.timingWheel.ScheduleRecurring("deferred:last_used", s.interval, s.flushLastUsed)
|
||||
log.Printf("[DeferredService] Started (interval: %v)", s.interval)
|
||||
}
|
||||
|
||||
// Stop stops the deferred service
|
||||
func (s *DeferredService) Stop() {
|
||||
s.timingWheel.Cancel("deferred:last_used")
|
||||
s.flushLastUsed()
|
||||
log.Printf("[DeferredService] Service stopped")
|
||||
}
|
||||
|
||||
func (s *DeferredService) ScheduleLastUsedUpdate(accountID int64) {
|
||||
s.lastUsedUpdates.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
func (s *DeferredService) flushLastUsed() {
|
||||
updates := make(map[int64]time.Time)
|
||||
s.lastUsedUpdates.Range(func(key, value any) bool {
|
||||
id, ok := key.(int64)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
ts, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
updates[id] = ts
|
||||
s.lastUsedUpdates.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.accountRepo.BatchUpdateLastUsed(ctx, updates); err != nil {
|
||||
log.Printf("[DeferredService] BatchUpdateLastUsed failed (%d accounts): %v", len(updates), err)
|
||||
for id, ts := range updates {
|
||||
s.lastUsedUpdates.Store(id, ts)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[DeferredService] BatchUpdateLastUsed flushed %d accounts", len(updates))
|
||||
}
|
||||
}
|
||||
159
backend/internal/service/domain_constants.go
Normal file
159
backend/internal/service/domain_constants.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
// Status constants
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
)
|
||||
|
||||
// Role constants
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
PlatformAntigravity = "antigravity"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
)
|
||||
|
||||
// Group subscription type constants
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
// Subscription status constants
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySMTPPort = "smtp_port" // SMTP端口
|
||||
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySMTPFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// LinuxDo Connect OAuth 登录设置
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
// Gemini 配额策略(JSON)
|
||||
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||
|
||||
// Model fallback settings
|
||||
SettingKeyEnableModelFallback = "enable_model_fallback"
|
||||
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
|
||||
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
|
||||
SettingKeyFallbackModelGemini = "fallback_model_gemini"
|
||||
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
|
||||
|
||||
// Request identity patch (Claude -> Gemini systemInstruction injection)
|
||||
SettingKeyEnableIdentityPatch = "enable_identity_patch"
|
||||
SettingKeyIdentityPatchPrompt = "identity_patch_prompt"
|
||||
|
||||
// =========================
|
||||
// Ops Monitoring (vNext)
|
||||
// =========================
|
||||
|
||||
// SettingKeyOpsMonitoringEnabled is a DB-backed soft switch to enable/disable ops module at runtime.
|
||||
SettingKeyOpsMonitoringEnabled = "ops_monitoring_enabled"
|
||||
|
||||
// SettingKeyOpsRealtimeMonitoringEnabled controls realtime features (e.g. WS/QPS push).
|
||||
SettingKeyOpsRealtimeMonitoringEnabled = "ops_realtime_monitoring_enabled"
|
||||
|
||||
// SettingKeyOpsQueryModeDefault controls the default query mode for ops dashboard (auto/raw/preagg).
|
||||
SettingKeyOpsQueryModeDefault = "ops_query_mode_default"
|
||||
|
||||
// SettingKeyOpsEmailNotificationConfig stores JSON config for ops email notifications.
|
||||
SettingKeyOpsEmailNotificationConfig = "ops_email_notification_config"
|
||||
|
||||
// SettingKeyOpsAlertRuntimeSettings stores JSON config for ops alert evaluator runtime settings.
|
||||
SettingKeyOpsAlertRuntimeSettings = "ops_alert_runtime_settings"
|
||||
|
||||
// SettingKeyOpsMetricsIntervalSeconds controls the ops metrics collector interval (>=60).
|
||||
SettingKeyOpsMetricsIntervalSeconds = "ops_metrics_interval_seconds"
|
||||
|
||||
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
|
||||
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
|
||||
|
||||
// =========================
|
||||
// Stream Timeout Handling
|
||||
// =========================
|
||||
|
||||
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
|
||||
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
|
||||
)
|
||||
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
109
backend/internal/service/email_queue_service.go
Normal file
109
backend/internal/service/email_queue_service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// EmailTask 邮件发送任务
|
||||
type EmailTask struct {
|
||||
Email string
|
||||
SiteName string
|
||||
TaskType string // "verify_code"
|
||||
}
|
||||
|
||||
// EmailQueueService 异步邮件队列服务
|
||||
type EmailQueueService struct {
|
||||
emailService *EmailService
|
||||
taskChan chan EmailTask
|
||||
wg sync.WaitGroup
|
||||
stopChan chan struct{}
|
||||
workers int
|
||||
}
|
||||
|
||||
// NewEmailQueueService 创建邮件队列服务
|
||||
func NewEmailQueueService(emailService *EmailService, workers int) *EmailQueueService {
|
||||
if workers <= 0 {
|
||||
workers = 3 // 默认3个工作协程
|
||||
}
|
||||
|
||||
service := &EmailQueueService{
|
||||
emailService: emailService,
|
||||
taskChan: make(chan EmailTask, 100), // 缓冲100个任务
|
||||
stopChan: make(chan struct{}),
|
||||
workers: workers,
|
||||
}
|
||||
|
||||
// 启动工作协程
|
||||
service.start()
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// start 启动工作协程
|
||||
func (s *EmailQueueService) start() {
|
||||
for i := 0; i < s.workers; i++ {
|
||||
s.wg.Add(1)
|
||||
go s.worker(i)
|
||||
}
|
||||
log.Printf("[EmailQueue] Started %d workers", s.workers)
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (s *EmailQueueService) worker(id int) {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case task := <-s.taskChan:
|
||||
s.processTask(id, task)
|
||||
case <-s.stopChan:
|
||||
log.Printf("[EmailQueue] Worker %d stopping", id)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTask 处理任务
|
||||
func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
switch task.TaskType {
|
||||
case "verify_code":
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
default:
|
||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueueVerifyCode 将验证码发送任务加入队列
|
||||
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: "verify_code",
|
||||
}
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止队列服务
|
||||
func (s *EmailQueueService) Stop() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
log.Println("[EmailQueue] All workers stopped")
|
||||
}
|
||||
359
backend/internal/service/email_service.go
Normal file
359
backend/internal/service/email_service.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmailNotConfigured = infraerrors.ServiceUnavailable("EMAIL_NOT_CONFIGURED", "email service not configured")
|
||||
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
||||
)
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
}
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
type VerificationCodeData struct {
|
||||
Code string
|
||||
Attempts int
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// SMTPConfig SMTP配置
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
FromName string
|
||||
UseTLS bool
|
||||
}
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo SettingRepository
|
||||
cache EmailCache
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
cache: cache,
|
||||
}
|
||||
}
|
||||
|
||||
// GetSMTPConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
keys := []string{
|
||||
SettingKeySMTPHost,
|
||||
SettingKeySMTPPort,
|
||||
SettingKeySMTPUsername,
|
||||
SettingKeySMTPPassword,
|
||||
SettingKeySMTPFrom,
|
||||
SettingKeySMTPFromName,
|
||||
SettingKeySMTPUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[SettingKeySMTPHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[SettingKeySMTPUseTLS] == "true"
|
||||
|
||||
return &SMTPConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[SettingKeySMTPUsername],
|
||||
Password: settings[SettingKeySMTPPassword],
|
||||
From: settings[SettingKeySMTPFrom],
|
||||
FromName: settings[SettingKeySMTPFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件(使用数据库中保存的配置)
|
||||
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
|
||||
config, err := s.GetSMTPConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.SendEmailWithConfig(config, to, subject, body)
|
||||
}
|
||||
|
||||
// SendEmailWithConfig 使用指定配置发送邮件
|
||||
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
||||
from := config.From
|
||||
if config.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
|
||||
from, to, subject, body)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
|
||||
if config.UseTLS {
|
||||
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
|
||||
}
|
||||
|
||||
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
|
||||
}
|
||||
|
||||
// sendMailTLS 使用TLS发送邮件
|
||||
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: host,
|
||||
// 强制 TLS 1.2+,避免协议降级导致的弱加密风险。
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls dial: %w", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new smtp client: %w", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp auth: %w", err)
|
||||
}
|
||||
|
||||
if err = client.Mail(from); err != nil {
|
||||
return fmt.Errorf("smtp mail: %w", err)
|
||||
}
|
||||
|
||||
if err = client.Rcpt(to); err != nil {
|
||||
return fmt.Errorf("smtp rcpt: %w", err)
|
||||
}
|
||||
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp data: %w", err)
|
||||
}
|
||||
|
||||
_, err = w.Write(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write msg: %w", err)
|
||||
}
|
||||
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("close writer: %w", err)
|
||||
}
|
||||
|
||||
// Email is sent successfully after w.Close(), ignore Quit errors
|
||||
// Some SMTP servers return non-standard responses on QUIT
|
||||
_ = client.Quit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateVerifyCode 生成6位数字验证码
|
||||
func (s *EmailService) GenerateVerifyCode() (string, error) {
|
||||
const digits = "0123456789"
|
||||
code := make([]byte, 6)
|
||||
for i := range code {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = digits[num.Int64()]
|
||||
}
|
||||
return string(code), nil
|
||||
}
|
||||
|
||||
// SendVerifyCode 发送验证码邮件
|
||||
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
|
||||
// 检查是否在冷却期内
|
||||
existing, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||
return ErrVerifyCodeTooFrequent
|
||||
}
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
code, err := s.GenerateVerifyCode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate code: %w", err)
|
||||
}
|
||||
|
||||
// 保存验证码到 Redis
|
||||
data := &VerificationCodeData{
|
||||
Code: code,
|
||||
Attempts: 0,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
return fmt.Errorf("save verify code: %w", err)
|
||||
}
|
||||
|
||||
// 构建邮件内容
|
||||
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
|
||||
body := s.buildVerifyCodeEmailBody(code, siteName)
|
||||
|
||||
// 发送邮件
|
||||
if err := s.SendEmail(ctx, email, subject, body); err != nil {
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
|
||||
data, err := s.cache.GetVerificationCode(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
|
||||
// 检查是否已达到最大尝试次数
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
data.Attempts++
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
}
|
||||
if data.Attempts >= maxVerifyCodeAttempts {
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
return ErrInvalidVerifyCode
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete verification code after success: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
|
||||
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
|
||||
.header h1 { margin: 0; font-size: 24px; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
|
||||
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>%s</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p style="font-size: 18px; color: #333;">Your verification code is:</p>
|
||||
<div class="code">%s</div>
|
||||
<div class="info">
|
||||
<p>This code will expire in <strong>15 minutes</strong>.</p>
|
||||
<p>If you did not request this code, please ignore this email.</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>This is an automated message, please do not reply.</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, siteName, code)
|
||||
}
|
||||
|
||||
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: config.Host,
|
||||
// 与发送逻辑一致,显式要求 TLS 1.2+。
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls connection failed: %w", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
client, err := smtp.NewClient(conn, config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp client creation failed: %w", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp authentication failed: %w", err)
|
||||
}
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
// 非TLS连接测试
|
||||
client, err := smtp.Dial(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtp connection failed: %w", err)
|
||||
}
|
||||
defer func() { _ = client.Close() }()
|
||||
|
||||
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("smtp authentication failed: %w", err)
|
||||
}
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
1467
backend/internal/service/gateway_multiplatform_test.go
Normal file
1467
backend/internal/service/gateway_multiplatform_test.go
Normal file
File diff suppressed because it is too large
Load Diff
233
backend/internal/service/gateway_prompt_test.go
Normal file
233
backend/internal/service/gateway_prompt_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsClaudeCodeClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
metadataUserID string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Claude Code client",
|
||||
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
|
||||
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Claude Code without version suffix",
|
||||
userAgent: "claude-cli/2.0.0",
|
||||
metadataUserID: "session_abc",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Missing metadata user_id",
|
||||
userAgent: "claude-cli/1.0.0",
|
||||
metadataUserID: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Different user agent",
|
||||
userAgent: "curl/7.68.0",
|
||||
metadataUserID: "user123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Empty user agent",
|
||||
userAgent: "",
|
||||
metadataUserID: "user123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Similar but not Claude CLI",
|
||||
userAgent: "claude-api/1.0.0",
|
||||
metadataUserID: "user123",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system any
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil system",
|
||||
system: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
system: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "string with Claude Code prompt",
|
||||
system: claudeCodeSystemPrompt,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "string with different content",
|
||||
system: "You are a helpful assistant.",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
system: []any{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array with Claude Code prompt",
|
||||
system: []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array with Claude Code prompt in second position",
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": "First prompt"},
|
||||
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array without Claude Code prompt",
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": "Custom prompt"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array with partial match (should not match)",
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": "You are Claude"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := systemIncludesClaudeCodePrompt(tt.system)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
system any
|
||||
wantSystemLen int
|
||||
wantFirstText string
|
||||
wantSecondText string
|
||||
}{
|
||||
{
|
||||
name: "nil system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: nil,
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
{
|
||||
name: "empty string system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: "",
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
{
|
||||
name: "string system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: "Custom prompt",
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom prompt",
|
||||
},
|
||||
{
|
||||
name: "string system equals Claude Code prompt",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: claudeCodeSystemPrompt,
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
{
|
||||
name: "array system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: []any{map[string]any{"type": "text", "text": "Custom"}},
|
||||
// Claude Code + Custom = 2
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom",
|
||||
},
|
||||
{
|
||||
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
|
||||
map[string]any{"type": "text", "text": "Other"},
|
||||
},
|
||||
// Claude Code at start + Other = 2 (deduped)
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Other",
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: []any{},
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := injectClaudeCodePrompt([]byte(tt.body), tt.system)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(result, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
system, ok := parsed["system"].([]any)
|
||||
require.True(t, ok, "system should be an array")
|
||||
require.Len(t, system, tt.wantSystemLen)
|
||||
|
||||
first, ok := system[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.wantFirstText, first["text"])
|
||||
require.Equal(t, "text", first["type"])
|
||||
|
||||
// Check cache_control
|
||||
cc, ok := first["cache_control"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ephemeral", cc["type"])
|
||||
|
||||
if tt.wantSecondText != "" && len(system) > 1 {
|
||||
second, ok := system[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.wantSecondText, second["text"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
505
backend/internal/service/gateway_request.go
Normal file
505
backend/internal/service/gateway_request.go
Normal file
@@ -0,0 +1,505 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ParsedRequest 保存网关请求的预解析结果
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在多个位置重复解析请求体(Handler、Service 各解析一次):
|
||||
// 1. gateway_handler.go 解析获取 model 和 stream
|
||||
// 2. gateway_service.go 再次解析获取 system、messages、metadata
|
||||
// 3. GenerateSessionHash 又一次解析获取会话哈希所需字段
|
||||
//
|
||||
// 新实现一次解析,多处复用:
|
||||
// 1. 在 Handler 层统一调用 ParseGatewayRequest 一次性解析
|
||||
// 2. 将解析结果 ParsedRequest 传递给 Service 层
|
||||
// 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销
|
||||
type ParsedRequest struct {
|
||||
Body []byte // 原始请求体(保留用于转发)
|
||||
Model string // 请求的模型名称
|
||||
Stream bool // 是否为流式请求
|
||||
MetadataUserID string // metadata.user_id(用于会话亲和)
|
||||
System any // system 字段内容
|
||||
Messages []any // messages 数组
|
||||
HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入)
|
||||
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
}
|
||||
|
||||
if rawModel, exists := req["model"]; exists {
|
||||
model, ok := rawModel.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid model field type")
|
||||
}
|
||||
parsed.Model = model
|
||||
}
|
||||
if rawStream, exists := req["stream"]; exists {
|
||||
stream, ok := rawStream.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid stream field type")
|
||||
}
|
||||
parsed.Stream = stream
|
||||
}
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// FilterThinkingBlocks removes thinking blocks from request body
|
||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||
// This prevents 400 errors from invalid thinking block signatures
|
||||
//
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
||||
func FilterThinkingBlocks(body []byte) []byte {
|
||||
return filterThinkingBlocksInternal(body, false)
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry strips thinking-related constructs for retry scenarios.
|
||||
//
|
||||
// Why:
|
||||
// - Upstreams may reject historical `thinking`/`redacted_thinking` blocks due to invalid/missing signatures.
|
||||
// - Anthropic extended thinking has a structural constraint: when top-level `thinking` is enabled and the
|
||||
// final message is an assistant prefill, the assistant content must start with a thinking block.
|
||||
// - If we remove thinking blocks but keep top-level `thinking` enabled, we can trigger:
|
||||
// "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||||
//
|
||||
// Strategy (B: preserve content as text):
|
||||
// - Disable top-level `thinking` (remove `thinking` field).
|
||||
// - Convert `thinking` blocks to `text` blocks (preserve the thinking content).
|
||||
// - Remove `redacted_thinking` blocks (cannot be converted to text).
|
||||
// - Ensure no message ends up with empty content.
|
||||
func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"type": "thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) ||
|
||||
bytes.Contains(body, []byte(`"thinking":`)) ||
|
||||
bytes.Contains(body, []byte(`"thinking" :`))
|
||||
|
||||
// Also check for empty content arrays that need fixing.
|
||||
// Note: This is a heuristic check; the actual empty content handling is done below.
|
||||
hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) ||
|
||||
bytes.Contains(body, []byte(`"content": []`)) ||
|
||||
bytes.Contains(body, []byte(`"content" : []`)) ||
|
||||
bytes.Contains(body, []byte(`"content" :[]`))
|
||||
|
||||
// Fast path: nothing to process
|
||||
if !hasThinkingContent && !hasEmptyContent {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
// Disable top-level thinking mode for retry to avoid structural/signature constraints upstream.
|
||||
if _, exists := req["thinking"]; exists {
|
||||
delete(req, "thinking")
|
||||
modified = true
|
||||
}
|
||||
|
||||
newMessages := make([]any, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
// String content or other format - keep as is
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
modifiedThisMsg := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
// Convert thinking blocks to text (preserve content) and drop redacted_thinking.
|
||||
switch blockType {
|
||||
case "thinking":
|
||||
modifiedThisMsg = true
|
||||
thinkingText, _ := blockMap["thinking"].(string)
|
||||
if thinkingText == "" {
|
||||
continue
|
||||
}
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
continue
|
||||
case "redacted_thinking":
|
||||
modifiedThisMsg = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle blocks without type discriminator but with a "thinking" field.
|
||||
if blockType == "" {
|
||||
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
modifiedThisMsg = true
|
||||
switch v := rawThinking.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": v})
|
||||
}
|
||||
default:
|
||||
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
// Handle empty content: either from filtering or originally empty
|
||||
if len(newContent) == 0 {
|
||||
modified = true
|
||||
placeholder := "(content removed)"
|
||||
if role == "assistant" {
|
||||
placeholder = "(assistant content removed)"
|
||||
}
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": placeholder,
|
||||
})
|
||||
msgMap["content"] = newContent
|
||||
} else if modifiedThisMsg {
|
||||
modified = true
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
|
||||
if modified {
|
||||
req["messages"] = newMessages
|
||||
} else {
|
||||
// Avoid rewriting JSON when no changes are needed.
|
||||
return body
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate
|
||||
// signature/thought_signature validation issues involving tool blocks.
|
||||
//
|
||||
// This performs everything in FilterThinkingBlocksForRetry, plus:
|
||||
// - Convert `tool_use` blocks to text (name/id/input) so we stop sending structured tool calls.
|
||||
// - Convert `tool_result` blocks to text so we keep tool results visible without tool semantics.
|
||||
//
|
||||
// Use this only when needed: converting tool blocks to text changes model behaviour and can increase the
|
||||
// risk of prompt injection (tool output becomes plain conversation text).
|
||||
func FilterSignatureSensitiveBlocksForRetry(body []byte) []byte {
|
||||
// Fast path: only run when we see likely relevant constructs.
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"tool_use"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "tool_use"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"tool_result"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "tool_result"`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking":`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking" :`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
|
||||
// Disable top-level thinking for retry to avoid structural/signature constraints upstream.
|
||||
if _, exists := req["thinking"]; exists {
|
||||
delete(req, "thinking")
|
||||
modified = true
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
newMessages := make([]any, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
modifiedThisMsg := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
switch blockType {
|
||||
case "thinking":
|
||||
modifiedThisMsg = true
|
||||
thinkingText, _ := blockMap["thinking"].(string)
|
||||
if thinkingText == "" {
|
||||
continue
|
||||
}
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText})
|
||||
continue
|
||||
case "redacted_thinking":
|
||||
modifiedThisMsg = true
|
||||
continue
|
||||
case "tool_use":
|
||||
modifiedThisMsg = true
|
||||
name, _ := blockMap["name"].(string)
|
||||
id, _ := blockMap["id"].(string)
|
||||
input := blockMap["input"]
|
||||
inputJSON, _ := json.Marshal(input)
|
||||
text := "(tool_use)"
|
||||
if name != "" {
|
||||
text += " name=" + name
|
||||
}
|
||||
if id != "" {
|
||||
text += " id=" + id
|
||||
}
|
||||
if len(inputJSON) > 0 && string(inputJSON) != "null" {
|
||||
text += " input=" + string(inputJSON)
|
||||
}
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": text})
|
||||
continue
|
||||
case "tool_result":
|
||||
modifiedThisMsg = true
|
||||
toolUseID, _ := blockMap["tool_use_id"].(string)
|
||||
isError, _ := blockMap["is_error"].(bool)
|
||||
content := blockMap["content"]
|
||||
contentJSON, _ := json.Marshal(content)
|
||||
text := "(tool_result)"
|
||||
if toolUseID != "" {
|
||||
text += " tool_use_id=" + toolUseID
|
||||
}
|
||||
if isError {
|
||||
text += " is_error=true"
|
||||
}
|
||||
if len(contentJSON) > 0 && string(contentJSON) != "null" {
|
||||
text += "\n" + string(contentJSON)
|
||||
}
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": text})
|
||||
continue
|
||||
}
|
||||
|
||||
if blockType == "" {
|
||||
if rawThinking, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
modifiedThisMsg = true
|
||||
switch v := rawThinking.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": v})
|
||||
}
|
||||
default:
|
||||
if b, err := json.Marshal(v); err == nil && len(b) > 0 {
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": string(b)})
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if modifiedThisMsg {
|
||||
modified = true
|
||||
if len(newContent) == 0 {
|
||||
placeholder := "(content removed)"
|
||||
if role == "assistant" {
|
||||
placeholder = "(assistant content removed)"
|
||||
}
|
||||
newContent = append(newContent, map[string]any{"type": "text", "text": placeholder})
|
||||
}
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return body
|
||||
}
|
||||
|
||||
req["messages"] = newMessages
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
// Fast path: if body doesn't contain "thinking", skip parsing
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking":`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking" :`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Check if thinking is enabled
|
||||
thinkingEnabled := false
|
||||
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||
thinkingEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
filtered := false
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
filteredThisMessage := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
// When thinking is enabled and this is an assistant message,
|
||||
// only keep thinking blocks with valid signatures
|
||||
if thinkingEnabled && role == "assistant" {
|
||||
signature, _ := blockMap["signature"].(string)
|
||||
if signature != "" && signature != "skip_thought_signature_validator" {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = true
|
||||
filteredThisMessage = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle blocks without type discriminator but with "thinking" key
|
||||
if blockType == "" {
|
||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
filtered = true
|
||||
filteredThisMessage = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if filteredThisMessage {
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
}
|
||||
|
||||
if !filtered {
|
||||
return body
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
298
backend/internal/service/gateway_request_test.go
Normal file
298
backend/internal/service/gateway_request_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseGatewayRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||
require.True(t, parsed.Stream)
|
||||
require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID)
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.NotNil(t, parsed.System)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.Nil(t, parsed.System)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
body := []byte(`{"model":123}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
body := []byte(`{"stream":"true"}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocks(t *testing.T) {
|
||||
containsThinkingBlock := func(body []byte) bool {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
}
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
if blockType == "thinking" {
|
||||
return true
|
||||
}
|
||||
if blockType == "" {
|
||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldFilter bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "filters thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "handles no thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles invalid JSON gracefully",
|
||||
input: `{invalid json`,
|
||||
shouldFilter: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "handles multiple messages with thinking blocks",
|
||||
input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "filters thinking blocks without type discriminator",
|
||||
input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "does not filter tool_use input fields named thinking",
|
||||
input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles empty messages array",
|
||||
input: `{"messages":[]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles missing messages field",
|
||||
input: `{"model":"claude-3"}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FilterThinkingBlocks([]byte(tt.input))
|
||||
|
||||
if tt.expectError {
|
||||
// For invalid JSON, should return original
|
||||
require.Equal(t, tt.input, string(result))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.shouldFilter {
|
||||
require.False(t, containsThinkingBlock(result))
|
||||
} else {
|
||||
// Ensure we don't rewrite JSON when no filtering is needed.
|
||||
require.Equal(t, tt.input, string(result))
|
||||
}
|
||||
|
||||
// Verify valid JSON returned (unless input was invalid)
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(result, &parsed)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_DisablesThinkingAndPreservesAsText(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hi"}]},
|
||||
{"role":"assistant","content":[
|
||||
{"type":"thinking","thinking":"Let me think...","signature":"bad_sig"},
|
||||
{"type":"text","text":"Answer"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, msgs, 2)
|
||||
|
||||
assistant, ok := msgs[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := assistant["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 2)
|
||||
|
||||
first, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", first["type"])
|
||||
require.Equal(t, "Let me think...", first["text"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_DisablesThinkingEvenWithoutThinkingBlocks(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"user","content":[{"type":"text","text":"Hi"}]},
|
||||
{"role":"assistant","content":[{"type":"text","text":"Prefill"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_RemovesRedactedThinkingAndKeepsValidContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"redacted_thinking","data":"..."},
|
||||
{"type":"text","text":"Visible"}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
msg0, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := msg0["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
content0, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", content0["type"])
|
||||
require.Equal(t, "Visible", content0["text"])
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled"},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"redacted_thinking","data":"..."}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterThinkingBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
msg0, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := msg0["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 1)
|
||||
content0, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", content0["type"])
|
||||
require.NotEmpty(t, content0["text"])
|
||||
}
|
||||
|
||||
func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"thinking":{"type":"enabled","budget_tokens":1024},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[
|
||||
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}},
|
||||
{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":false}
|
||||
]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := FilterSignatureSensitiveBlocksForRetry(input)
|
||||
|
||||
var req map[string]any
|
||||
require.NoError(t, json.Unmarshal(out, &req))
|
||||
_, hasThinking := req["thinking"]
|
||||
require.False(t, hasThinking)
|
||||
|
||||
msgs, ok := req["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
msg0, ok := msgs[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content, ok := msg0["content"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, content, 2)
|
||||
content0, ok := content[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
content1, ok := content[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", content0["type"])
|
||||
require.Equal(t, "text", content1["type"])
|
||||
require.Contains(t, content0["text"], "tool_use")
|
||||
require.Contains(t, content1["text"], "tool_result")
|
||||
}
|
||||
3031
backend/internal/service/gateway_service.go
Normal file
3031
backend/internal/service/gateway_service.go
Normal file
File diff suppressed because it is too large
Load Diff
50
backend/internal/service/gateway_service_benchmark_test.go
Normal file
50
backend/internal/service/gateway_service_benchmark_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var benchmarkStringSink string
|
||||
|
||||
// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。
|
||||
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
|
||||
svc := &GatewayService{}
|
||||
body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
b.Fatalf("解析请求失败: %v", err)
|
||||
}
|
||||
benchmarkStringSink = svc.GenerateSessionHash(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。
|
||||
func BenchmarkExtractCacheableContent_System(b *testing.B) {
|
||||
svc := &GatewayService{}
|
||||
req := buildSystemCacheableRequest(12)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkStringSink = svc.extractCacheableContent(req)
|
||||
}
|
||||
}
|
||||
|
||||
func buildSystemCacheableRequest(parts int) *ParsedRequest {
|
||||
systemParts := make([]any, 0, parts)
|
||||
for i := 0; i < parts; i++ {
|
||||
systemParts = append(systemParts, map[string]any{
|
||||
"text": "system_part_" + strconv.Itoa(i),
|
||||
"cache_control": map[string]any{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
return &ParsedRequest{
|
||||
System: systemParts,
|
||||
HasSystem: true,
|
||||
}
|
||||
}
|
||||
2818
backend/internal/service/gemini_messages_compat_service.go
Normal file
2818
backend/internal/service/gemini_messages_compat_service.go
Normal file
File diff suppressed because it is too large
Load Diff
128
backend/internal/service/gemini_messages_compat_service_test.go
Normal file
128
backend/internal/service/gemini_messages_compat_service_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools any
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tools",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "mcp_tool",
|
||||
"custom": map[string]any{
|
||||
"description": "MCP tool description",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从custom字段读取",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"name": "standard_tool",
|
||||
"description": "Standard",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "custom_tool",
|
||||
"custom": map[string]any{
|
||||
"description": "Custom",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "混合工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Custom tool without custom field",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "invalid_custom",
|
||||
// 缺少 custom 字段
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "缺少custom字段的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertClaudeToolsToGeminiTools(tt.tools)
|
||||
|
||||
if tt.expectedLen == 0 {
|
||||
if result != nil {
|
||||
t.Errorf("%s: expected nil result, got %v", tt.description, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatalf("%s: expected non-nil result", tt.description)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
|
||||
return
|
||||
}
|
||||
|
||||
toolDecl, ok := result[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
|
||||
}
|
||||
|
||||
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
|
||||
}
|
||||
|
||||
toolsArr, _ := tt.tools.([]any)
|
||||
expectedFuncCount := 0
|
||||
for _, tool := range toolsArr {
|
||||
toolMap, _ := tool.(map[string]any)
|
||||
if toolMap["name"] != "" {
|
||||
// 检查是否为有效的custom工具
|
||||
if toolMap["type"] == "custom" {
|
||||
if toolMap["custom"] != nil {
|
||||
expectedFuncCount++
|
||||
}
|
||||
} else {
|
||||
expectedFuncCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(funcDecls) != expectedFuncCount {
|
||||
t.Errorf("%s: expected %d function declarations, got %d",
|
||||
tt.description, expectedFuncCount, len(funcDecls))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
609
backend/internal/service/gemini_multiplatform_test.go
Normal file
609
backend/internal/service/gemini_multiplatform_test.go
Normal file
@@ -0,0 +1,609 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockAccountRepoForGemini Gemini 测试用的 mock
|
||||
type mockAccountRepoForGemini struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||
var result []*Account
|
||||
for _, id := range ids {
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
if m.accountsByID == nil {
|
||||
return false, nil
|
||||
}
|
||||
_, ok := m.accountsByID[id]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range m.accounts {
|
||||
if acc.Platform == platform && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
// 测试时不区分 groupID,直接按 platform 过滤
|
||||
return m.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
// Stub methods to implement AccountRepository interface
|
||||
func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListActive(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateLastUsed(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
var result []Account
|
||||
platformSet := make(map[string]bool)
|
||||
for _, p := range platforms {
|
||||
platformSet[p] = true
|
||||
}
|
||||
for _, acc := range m.accounts {
|
||||
if platformSet[acc.Platform] && acc.IsSchedulable() {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
return m.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Verify interface implementation
|
||||
var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
|
||||
type mockGroupRepoForGemini struct {
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
// Stub methods to implement GroupRepository interface
|
||||
func (m *mockGroupRepoForGemini) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGemini) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGemini) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil }
|
||||
func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var _ GroupRepository = (*mockGroupRepoForGemini)(nil)
|
||||
|
||||
// mockGatewayCacheForGemini Gemini 测试用的 cache mock
|
||||
type mockGatewayCacheForGemini struct {
|
||||
sessionBindings map[string]int64
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := m.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if m.sessionBindings == nil {
|
||||
m.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
m.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 3, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 无分组时使用 gemini 平台
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应选择优先级最高的 gemini 账户")
|
||||
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformGemini,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {ID: groupID, Platform: PlatformGemini},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被隔离
|
||||
{ID: 2, Platform: PlatformAntigravity, Priority: 1, Status: StatusActive, Schedulable: true}, // 应被选择
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
1: {ID: 1, Platform: PlatformAntigravity},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
groupID := int64(1)
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID)
|
||||
require.Equal(t, PlatformAntigravity, acc.Platform, "antigravity 分组应只返回 antigravity 账户")
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred 测试 OAuth 优先
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPreferred(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且都未使用时,应优先选择 OAuth 账户")
|
||||
require.Equal(t, AccountTypeOAuth, acc.Type)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts 测试无可用账户
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "", "gemini-2.5-flash", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, acc)
|
||||
require.Contains(t, err.Error(), "no available")
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession 测试粘性会话
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_StickySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("粘性会话命中-同平台", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
// 注意:缓存键使用 "gemini:" 前缀
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(1), acc.ID, "应返回粘性会话绑定的账户")
|
||||
})
|
||||
|
||||
t.Run("粘性会话平台不匹配-降级选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAntigravity, Priority: 2, Status: StatusActive, Schedulable: true}, // 粘性会话绑定
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"gemini:session-123": 1}, // 绑定 antigravity 账户
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 无分组时使用 gemini 平台,粘性会话绑定的 antigravity 账户平台不匹配
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话账户平台不匹配,应降级选择 gemini 账户")
|
||||
require.Equal(t, PlatformGemini, acc.Platform)
|
||||
})
|
||||
|
||||
t.Run("粘性会话不命中无前缀缓存键", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 2, Status: StatusActive, Schedulable: true},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
// 缓存键没有 "gemini:" 前缀,不应命中
|
||||
cache := &mockGatewayCacheForGemini{
|
||||
sessionBindings: map[string]int64{"session-123": 1},
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, nil, "session-123", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
// 粘性会话未命中,按优先级选择
|
||||
require.Equal(t, int64(2), acc.ID, "粘性会话未命中,应按优先级选择")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGeminiPlatformRouting_DocumentRouteDecision 测试平台路由决策逻辑
|
||||
func TestGeminiPlatformRouting_DocumentRouteDecision(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
expectedService string // "gemini" 表示 ForwardNative, "antigravity" 表示 ForwardGemini
|
||||
}{
|
||||
{
|
||||
name: "Gemini平台走ForwardNative",
|
||||
platform: PlatformGemini,
|
||||
expectedService: "gemini",
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台走ForwardGemini",
|
||||
platform: PlatformAntigravity,
|
||||
expectedService: "antigravity",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: tt.platform}
|
||||
|
||||
// 模拟 Handler 层的路由逻辑
|
||||
var serviceName string
|
||||
if account.Platform == PlatformAntigravity {
|
||||
serviceName = "antigravity"
|
||||
} else {
|
||||
serviceName = "gemini"
|
||||
}
|
||||
|
||||
require.Equal(t, tt.expectedService, serviceName,
|
||||
"平台 %s 应该路由到 %s 服务", tt.platform, tt.expectedService)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
|
||||
svc := &GeminiMessagesCompatService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Antigravity平台-支持gemini模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-支持claude模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "claude-3-5-sonnet-20241022",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Antigravity平台-不支持gpt模型",
|
||||
account: &Account{Platform: PlatformAntigravity},
|
||||
model: "gpt-4",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-无映射配置-支持所有模型",
|
||||
account: &Account{Platform: PlatformGemini},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.isModelSupportedByAccount(tt.account, tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
13
backend/internal/service/gemini_oauth.go
Normal file
13
backend/internal/service/gemini_oauth.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
// GeminiOAuthClient performs Google OAuth token exchange/refresh for Gemini integration.
|
||||
type GeminiOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error)
|
||||
}
|
||||
1074
backend/internal/service/gemini_oauth_service.go
Normal file
1074
backend/internal/service/gemini_oauth_service.go
Normal file
File diff suppressed because it is too large
Load Diff
130
backend/internal/service/gemini_oauth_service_test.go
Normal file
130
backend/internal/service/gemini_oauth_service_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
oauthType string
|
||||
projectID string
|
||||
wantClientID string
|
||||
wantRedirect string
|
||||
wantScope string
|
||||
wantProjectID string
|
||||
wantErrSubstr string
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "google_one uses built-in client when not configured and redirects to upstream",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{},
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "google_one always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "code_assist always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
oauthType: "code_assist",
|
||||
projectID: "my-gcp-project",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "my-gcp-project",
|
||||
},
|
||||
{
|
||||
name: "ai_studio requires custom client",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{},
|
||||
},
|
||||
},
|
||||
oauthType: "ai_studio",
|
||||
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
|
||||
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
|
||||
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL returned error: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(got.AuthURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse auth_url: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if gotState := q.Get("state"); gotState != got.State {
|
||||
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
|
||||
}
|
||||
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
|
||||
}
|
||||
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
|
||||
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
|
||||
}
|
||||
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
|
||||
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
|
||||
}
|
||||
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
|
||||
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
448
backend/internal/service/gemini_quota.go
Normal file
448
backend/internal/service/gemini_quota.go
Normal file
@@ -0,0 +1,448 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
type geminiModelClass string
|
||||
|
||||
const (
|
||||
geminiModelPro geminiModelClass = "pro"
|
||||
geminiModelFlash geminiModelClass = "flash"
|
||||
)
|
||||
|
||||
type GeminiQuota struct {
|
||||
// SharedRPD is a shared requests-per-day pool across models.
|
||||
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
|
||||
SharedRPD int64 `json:"shared_rpd,omitempty"`
|
||||
// SharedRPM is a shared requests-per-minute pool across models.
|
||||
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
|
||||
SharedRPM int64 `json:"shared_rpm,omitempty"`
|
||||
|
||||
// Per-model quotas (AI Studio / API key).
|
||||
// A value of -1 means "unlimited" (pay-as-you-go).
|
||||
ProRPD int64 `json:"pro_rpd,omitempty"`
|
||||
ProRPM int64 `json:"pro_rpm,omitempty"`
|
||||
FlashRPD int64 `json:"flash_rpd,omitempty"`
|
||||
FlashRPM int64 `json:"flash_rpm,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiTierPolicy struct {
|
||||
Quota GeminiQuota
|
||||
Cooldown time.Duration
|
||||
}
|
||||
|
||||
type GeminiQuotaPolicy struct {
|
||||
tiers map[string]GeminiTierPolicy
|
||||
}
|
||||
|
||||
type GeminiUsageTotals struct {
|
||||
ProRequests int64
|
||||
FlashRequests int64
|
||||
ProTokens int64
|
||||
FlashTokens int64
|
||||
ProCost float64
|
||||
FlashCost float64
|
||||
}
|
||||
|
||||
const geminiQuotaCacheTTL = time.Minute
|
||||
|
||||
type geminiQuotaOverridesV1 struct {
|
||||
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
|
||||
}
|
||||
|
||||
type geminiQuotaOverridesV2 struct {
|
||||
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
|
||||
}
|
||||
|
||||
type geminiQuotaRuleOverride struct {
|
||||
SharedRPD *int64 `json:"shared_rpd,omitempty"`
|
||||
SharedRPM *int64 `json:"rpm,omitempty"`
|
||||
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
|
||||
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
|
||||
Desc *string `json:"desc,omitempty"`
|
||||
}
|
||||
|
||||
type geminiModelQuotaOverride struct {
|
||||
RPD *int64 `json:"rpd,omitempty"`
|
||||
RPM *int64 `json:"rpm,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiQuotaService struct {
|
||||
cfg *config.Config
|
||||
settingRepo SettingRepository
|
||||
mu sync.Mutex
|
||||
cachedAt time.Time
|
||||
policy *GeminiQuotaPolicy
|
||||
}
|
||||
|
||||
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
|
||||
return &GeminiQuotaService{
|
||||
cfg: cfg,
|
||||
settingRepo: settingRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
if s == nil {
|
||||
return newGeminiQuotaPolicy()
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
|
||||
policy := s.policy
|
||||
s.mu.Unlock()
|
||||
return policy
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
policy := newGeminiQuotaPolicy()
|
||||
if s.cfg != nil {
|
||||
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
|
||||
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
|
||||
raw := []byte(s.cfg.Gemini.Quota.Policy)
|
||||
var overridesV2 geminiQuotaOverridesV2
|
||||
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
||||
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
||||
} else {
|
||||
var overridesV1 geminiQuotaOverridesV1
|
||||
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
||||
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overridesV1.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.settingRepo != nil {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
|
||||
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||
log.Printf("gemini quota: load setting failed: %v", err)
|
||||
} else if strings.TrimSpace(value) != "" {
|
||||
raw := []byte(value)
|
||||
var overridesV2 geminiQuotaOverridesV2
|
||||
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
||||
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
||||
} else {
|
||||
var overridesV1 geminiQuotaOverridesV1
|
||||
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
||||
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overridesV1.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.policy = policy
|
||||
s.cachedAt = now
|
||||
s.mu.Unlock()
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
|
||||
// Map (oauth_type + tier_id) to a canonical policy tier key.
|
||||
// This keeps the policy table stable even if upstream tier_id strings vary.
|
||||
tierKey := geminiQuotaTierKeyForAccount(account)
|
||||
if tierKey == "" {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
|
||||
policy := s.Policy(ctx)
|
||||
return policy.QuotaForTier(tierKey)
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
|
||||
policy := s.Policy(ctx)
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
|
||||
if s == nil || account == nil || account.Platform != PlatformGemini {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
tierKey := geminiQuotaTierKeyForAccount(account)
|
||||
if strings.TrimSpace(tierKey) == "" {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
return s.CooldownForTier(ctx, tierKey)
|
||||
}
|
||||
|
||||
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
|
||||
return &GeminiQuotaPolicy{
|
||||
tiers: map[string]GeminiTierPolicy{
|
||||
// --- AI Studio / API Key (per-model) ---
|
||||
// aistudio_free:
|
||||
// - gemini_pro: 50 RPD / 2 RPM
|
||||
// - gemini_flash: 1500 RPD / 15 RPM
|
||||
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
|
||||
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
|
||||
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
|
||||
|
||||
// --- Google One (shared pool) ---
|
||||
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
|
||||
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
|
||||
// --- GCP Code Assist (shared pool) ---
|
||||
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
|
||||
if p == nil || len(tiers) == 0 {
|
||||
return
|
||||
}
|
||||
for rawID, override := range tiers {
|
||||
tierID := normalizeGeminiTierID(rawID)
|
||||
if tierID == "" {
|
||||
continue
|
||||
}
|
||||
policy, ok := p.tiers[tierID]
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
// Backward-compatible overrides:
|
||||
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
|
||||
// - Otherwise apply per-model overrides.
|
||||
if override.ProRPD != nil {
|
||||
if policy.Quota.SharedRPD > 0 {
|
||||
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
||||
} else {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
||||
}
|
||||
}
|
||||
if override.FlashRPD != nil {
|
||||
if policy.Quota.SharedRPD > 0 {
|
||||
// No separate flash RPD for shared tiers.
|
||||
} else {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
|
||||
}
|
||||
}
|
||||
if override.CooldownMinutes != nil {
|
||||
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
|
||||
policy.Cooldown = time.Duration(minutes) * time.Minute
|
||||
}
|
||||
p.tiers[tierID] = policy
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
|
||||
if p == nil || len(rules) == 0 {
|
||||
return
|
||||
}
|
||||
for rawID, override := range rules {
|
||||
tierID := normalizeGeminiTierID(rawID)
|
||||
if tierID == "" {
|
||||
continue
|
||||
}
|
||||
policy, ok := p.tiers[tierID]
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
|
||||
if override.SharedRPD != nil {
|
||||
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
|
||||
}
|
||||
if override.SharedRPM != nil {
|
||||
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
|
||||
}
|
||||
if override.GeminiPro != nil {
|
||||
if override.GeminiPro.RPD != nil {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
|
||||
}
|
||||
if override.GeminiPro.RPM != nil {
|
||||
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
|
||||
}
|
||||
}
|
||||
if override.GeminiFlash != nil {
|
||||
if override.GeminiFlash.RPD != nil {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
|
||||
}
|
||||
if override.GeminiFlash.RPM != nil {
|
||||
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
|
||||
}
|
||||
}
|
||||
|
||||
p.tiers[tierID] = policy
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
|
||||
policy, ok := p.policyForTier(tierID)
|
||||
if !ok {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
return policy.Quota, true
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
|
||||
policy, ok := p.policyForTier(tierID)
|
||||
if ok && policy.Cooldown > 0 {
|
||||
return policy.Cooldown
|
||||
}
|
||||
return 5 * time.Minute
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
|
||||
if p == nil {
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
normalized := normalizeGeminiTierID(tierID)
|
||||
if policy, ok := p.tiers[normalized]; ok {
|
||||
return policy, true
|
||||
}
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
|
||||
func normalizeGeminiTierID(tierID string) string {
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if tierID == "" {
|
||||
return ""
|
||||
}
|
||||
// Prefer canonical mapping (handles legacy tier strings).
|
||||
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
|
||||
return canonical
|
||||
}
|
||||
// Accept older policy keys that used uppercase names.
|
||||
switch strings.ToUpper(tierID) {
|
||||
case "AISTUDIO_FREE":
|
||||
return GeminiTierAIStudioFree
|
||||
case "AISTUDIO_PAID":
|
||||
return GeminiTierAIStudioPaid
|
||||
case "GOOGLE_ONE_FREE":
|
||||
return GeminiTierGoogleOneFree
|
||||
case "GOOGLE_AI_PRO":
|
||||
return GeminiTierGoogleAIPro
|
||||
case "GOOGLE_AI_ULTRA":
|
||||
return GeminiTierGoogleAIUltra
|
||||
case "GCP_STANDARD":
|
||||
return GeminiTierGCPStandard
|
||||
case "GCP_ENTERPRISE":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
return strings.ToLower(tierID)
|
||||
}
|
||||
|
||||
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
|
||||
if value < -1 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func clampGeminiQuotaInt(value int) int {
|
||||
if value < 0 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func clampGeminiQuotaRPM(value int64) int64 {
|
||||
if value < 0 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func geminiCooldownForTier(tierID string) time.Duration {
|
||||
policy := newGeminiQuotaPolicy()
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func geminiQuotaTierKeyForAccount(account *Account) string {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
|
||||
oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
|
||||
rawTier := strings.TrimSpace(account.GeminiTierID())
|
||||
|
||||
// Prefer the canonical tier stored in credentials.
|
||||
if tierID := canonicalGeminiTierIDForOAuthType(oauthType, rawTier); tierID != "" && tierID != GeminiTierGoogleOneUnknown {
|
||||
return tierID
|
||||
}
|
||||
|
||||
// Fallback defaults when tier_id is missing or unknown.
|
||||
switch oauthType {
|
||||
case "google_one":
|
||||
return GeminiTierGoogleOneFree
|
||||
case "code_assist":
|
||||
return GeminiTierGCPStandard
|
||||
case "ai_studio":
|
||||
return GeminiTierAIStudioFree
|
||||
default:
|
||||
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
|
||||
return GeminiTierAIStudioFree
|
||||
}
|
||||
}
|
||||
|
||||
func geminiModelClassFromName(model string) geminiModelClass {
|
||||
name := strings.ToLower(strings.TrimSpace(model))
|
||||
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
|
||||
return geminiModelFlash
|
||||
}
|
||||
return geminiModelPro
|
||||
}
|
||||
|
||||
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
|
||||
var totals GeminiUsageTotals
|
||||
for _, stat := range stats {
|
||||
switch geminiModelClassFromName(stat.Model) {
|
||||
case geminiModelFlash:
|
||||
totals.FlashRequests += stat.Requests
|
||||
totals.FlashTokens += stat.TotalTokens
|
||||
totals.FlashCost += stat.ActualCost
|
||||
default:
|
||||
totals.ProRequests += stat.Requests
|
||||
totals.ProTokens += stat.TotalTokens
|
||||
totals.ProCost += stat.ActualCost
|
||||
}
|
||||
}
|
||||
return totals
|
||||
}
|
||||
|
||||
func geminiQuotaLocation() *time.Location {
|
||||
loc, err := time.LoadLocation("America/Los_Angeles")
|
||||
if err != nil {
|
||||
return time.FixedZone("PST", -8*3600)
|
||||
}
|
||||
return loc
|
||||
}
|
||||
|
||||
func geminiDailyWindowStart(now time.Time) time.Time {
|
||||
loc := geminiQuotaLocation()
|
||||
localNow := now.In(loc)
|
||||
return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
|
||||
}
|
||||
|
||||
func geminiDailyResetTime(now time.Time) time.Time {
|
||||
loc := geminiQuotaLocation()
|
||||
localNow := now.In(loc)
|
||||
start := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
|
||||
reset := start.Add(24 * time.Hour)
|
||||
if !reset.After(localNow) {
|
||||
reset = reset.Add(24 * time.Hour)
|
||||
}
|
||||
return reset
|
||||
}
|
||||
17
backend/internal/service/gemini_token_cache.go
Normal file
17
backend/internal/service/gemini_token_cache.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GeminiTokenCache stores short-lived access tokens and coordinates refresh to avoid stampedes.
|
||||
type GeminiTokenCache interface {
|
||||
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
|
||||
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
|
||||
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
|
||||
DeleteAccessToken(ctx context.Context, cacheKey string) error
|
||||
|
||||
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
|
||||
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
|
||||
}
|
||||
160
backend/internal/service/gemini_token_provider.go
Normal file
160
backend/internal/service/gemini_token_provider.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
geminiTokenRefreshSkew = 3 * time.Minute
|
||||
geminiTokenCacheSkew = 5 * time.Minute
|
||||
)
|
||||
|
||||
type GeminiTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache GeminiTokenCache
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
}
|
||||
|
||||
func NewGeminiTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache GeminiTokenCache,
|
||||
geminiOAuthService *GeminiOAuthService,
|
||||
) *GeminiTokenProvider {
|
||||
return &GeminiTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
geminiOAuthService: geminiOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not a gemini oauth account")
|
||||
}
|
||||
|
||||
cacheKey := GeminiTokenCacheKey(account)
|
||||
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Re-check after lock (another worker may have refreshed).
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
|
||||
if p.geminiOAuthService == nil {
|
||||
return "", errors.New("gemini oauth service not configured")
|
||||
}
|
||||
tokenInfo, err := p.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
newCredentials := p.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// project_id is optional now:
|
||||
// - If present: will use Code Assist API (requires project_id)
|
||||
// - If absent: will use AI Studio API with OAuth token (like regular API key mode)
|
||||
// Auto-detect project_id only if explicitly enabled via a credential flag
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
|
||||
|
||||
if projectID == "" && autoDetectProjectID {
|
||||
if p.geminiOAuthService == nil {
|
||||
return accessToken, nil // Fallback to AI Studio API mode
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil {
|
||||
if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
||||
return accessToken, nil
|
||||
}
|
||||
detected = strings.TrimSpace(detected)
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if detected != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["project_id"] = detected
|
||||
if tierID != "" {
|
||||
account.Credentials["tier_id"] = tierID
|
||||
}
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > geminiTokenCacheSkew:
|
||||
ttl = until - geminiTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func GeminiTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return projectID
|
||||
}
|
||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
45
backend/internal/service/gemini_token_refresher.go
Normal file
45
backend/internal/service/gemini_token_refresher.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GeminiTokenRefresher struct {
|
||||
geminiOAuthService *GeminiOAuthService
|
||||
}
|
||||
|
||||
func NewGeminiTokenRefresher(geminiOAuthService *GeminiOAuthService) *GeminiTokenRefresher {
|
||||
return &GeminiTokenRefresher{geminiOAuthService: geminiOAuthService}
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformGemini && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Until(*expiresAt) < refreshWindow
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.geminiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.geminiOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
13
backend/internal/service/geminicli_codeassist.go
Normal file
13
backend/internal/service/geminicli_codeassist.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
// GeminiCliCodeAssistClient calls GeminiCli internal Code Assist endpoints.
|
||||
type GeminiCliCodeAssistClient interface {
|
||||
LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error)
|
||||
OnboardUser(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error)
|
||||
}
|
||||
92
backend/internal/service/group.go
Normal file
92
backend/internal/service/group.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type Group struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
Platform string
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
Status string
|
||||
Hydrated bool // indicates the group was loaded from a trusted repository source
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
WeeklyLimitUSD *float64
|
||||
MonthlyLimitUSD *float64
|
||||
DefaultValidityDays int
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
// Claude Code 客户端限制
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
AccountGroups []AccountGroup
|
||||
AccountCount int64
|
||||
}
|
||||
|
||||
func (g *Group) IsActive() bool {
|
||||
return g.Status == StatusActive
|
||||
}
|
||||
|
||||
func (g *Group) IsSubscriptionType() bool {
|
||||
return g.SubscriptionType == SubscriptionTypeSubscription
|
||||
}
|
||||
|
||||
func (g *Group) IsFreeSubscription() bool {
|
||||
return g.IsSubscriptionType() && g.RateMultiplier == 0
|
||||
}
|
||||
|
||||
func (g *Group) HasDailyLimit() bool {
|
||||
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
|
||||
}
|
||||
|
||||
func (g *Group) HasWeeklyLimit() bool {
|
||||
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
|
||||
}
|
||||
|
||||
func (g *Group) HasMonthlyLimit() bool {
|
||||
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
|
||||
}
|
||||
|
||||
// GetImagePrice 根据 image_size 返回对应的图片生成价格
|
||||
// 如果分组未配置价格,返回 nil(调用方应使用默认值)
|
||||
func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
return g.ImagePrice1K
|
||||
case "2K":
|
||||
return g.ImagePrice2K
|
||||
case "4K":
|
||||
return g.ImagePrice4K
|
||||
default:
|
||||
// 未知尺寸默认按 2K 计费
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if group.ID <= 0 {
|
||||
return false
|
||||
}
|
||||
if !group.Hydrated {
|
||||
return false
|
||||
}
|
||||
if group.Platform == "" || group.Status == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
208
backend/internal/service/group_service.go
Normal file
208
backend/internal/service/group_service.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrGroupNotFound = infraerrors.NotFound("GROUP_NOT_FOUND", "group not found")
|
||||
ErrGroupExists = infraerrors.Conflict("GROUP_EXISTS", "group name already exists")
|
||||
)
|
||||
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *Group) error
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
GetByIDLite(ctx context.Context, id int64) (*Group, error)
|
||||
Update(ctx context.Context, group *Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// CreateGroupRequest 创建分组请求
|
||||
type CreateGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
IsExclusive bool `json:"is_exclusive"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest 更新分组请求
|
||||
type UpdateGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
RateMultiplier *float64 `json:"rate_multiplier"`
|
||||
IsExclusive *bool `json:"is_exclusive"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo GroupRepository
|
||||
authCacheInvalidator APIKeyAuthCacheInvalidator
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
authCacheInvalidator: authCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建分组
|
||||
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
|
||||
// 检查名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrGroupExists
|
||||
}
|
||||
|
||||
// 创建分组
|
||||
group := &Group{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("create group: %w", err)
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取分组
|
||||
func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// List 获取分组列表
|
||||
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
groups, pagination, err := s.groupRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list groups: %w", err)
|
||||
}
|
||||
return groups, pagination, nil
|
||||
}
|
||||
|
||||
// ListActive 获取活跃分组列表
|
||||
func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) {
|
||||
groups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// Update 更新分组
|
||||
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil && *req.Name != group.Name {
|
||||
// 检查新名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, *req.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check group exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, ErrGroupExists
|
||||
}
|
||||
group.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.Description != nil {
|
||||
group.Description = *req.Description
|
||||
}
|
||||
|
||||
if req.RateMultiplier != nil {
|
||||
group.RateMultiplier = *req.RateMultiplier
|
||||
}
|
||||
|
||||
if req.IsExclusive != nil {
|
||||
group.IsExclusive = *req.IsExclusive
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
group.Status = *req.Status
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, fmt.Errorf("update group: %w", err)
|
||||
}
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// Delete 删除分组
|
||||
func (s *GroupService) Delete(ctx context.Context, id int64) error {
|
||||
// 检查分组是否存在
|
||||
_, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id)
|
||||
}
|
||||
if err := s.groupRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取分组统计信息
|
||||
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 获取账号数量
|
||||
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account count: %w", err)
|
||||
}
|
||||
|
||||
stats := map[string]any{
|
||||
"id": group.ID,
|
||||
"name": group.Name,
|
||||
"rate_multiplier": group.RateMultiplier,
|
||||
"is_exclusive": group.IsExclusive,
|
||||
"status": group.Status,
|
||||
"account_count": accountCount,
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
92
backend/internal/service/group_test.go
Normal file
92
backend/internal/service/group_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGroup_GetImagePrice_1K 测试 1K 尺寸返回正确价格
|
||||
func TestGroup_GetImagePrice_1K(t *testing.T) {
|
||||
price := 0.10
|
||||
group := &Group{
|
||||
ImagePrice1K: &price,
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("1K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.10, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_2K 测试 2K 尺寸返回正确价格
|
||||
func TestGroup_GetImagePrice_2K(t *testing.T) {
|
||||
price := 0.15
|
||||
group := &Group{
|
||||
ImagePrice2K: &price,
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("2K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.15, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_4K 测试 4K 尺寸返回正确价格
|
||||
func TestGroup_GetImagePrice_4K(t *testing.T) {
|
||||
price := 0.30
|
||||
group := &Group{
|
||||
ImagePrice4K: &price,
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("4K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.30, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_UnknownSize 测试未知尺寸回退 2K
|
||||
func TestGroup_GetImagePrice_UnknownSize(t *testing.T) {
|
||||
price2K := 0.15
|
||||
group := &Group{
|
||||
ImagePrice2K: &price2K,
|
||||
}
|
||||
|
||||
// 未知尺寸 "3K" 应该回退到 2K
|
||||
result := group.GetImagePrice("3K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.15, *result, 0.0001)
|
||||
|
||||
// 空字符串也回退到 2K
|
||||
result = group.GetImagePrice("")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.15, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_NilValues 测试未配置时返回 nil
|
||||
func TestGroup_GetImagePrice_NilValues(t *testing.T) {
|
||||
group := &Group{
|
||||
// 所有 ImagePrice 字段都是 nil
|
||||
}
|
||||
|
||||
require.Nil(t, group.GetImagePrice("1K"))
|
||||
require.Nil(t, group.GetImagePrice("2K"))
|
||||
require.Nil(t, group.GetImagePrice("4K"))
|
||||
require.Nil(t, group.GetImagePrice("unknown"))
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_PartialConfig 测试部分配置
|
||||
func TestGroup_GetImagePrice_PartialConfig(t *testing.T) {
|
||||
price1K := 0.10
|
||||
group := &Group{
|
||||
ImagePrice1K: &price1K,
|
||||
// ImagePrice2K 和 ImagePrice4K 未配置
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("1K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.10, *result, 0.0001)
|
||||
|
||||
// 2K 和 4K 返回 nil
|
||||
require.Nil(t, group.GetImagePrice("2K"))
|
||||
require.Nil(t, group.GetImagePrice("4K"))
|
||||
}
|
||||
30
backend/internal/service/http_upstream_port.go
Normal file
30
backend/internal/service/http_upstream_port.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPUpstream 上游 HTTP 请求接口
|
||||
// 用于向上游 API(Claude、OpenAI、Gemini 等)发送请求
|
||||
// 这是一个通用接口,可用于任何基于 HTTP 的上游服务
|
||||
//
|
||||
// 设计说明:
|
||||
// - 支持可选代理配置
|
||||
// - 支持账户级连接池隔离
|
||||
// - 实现类负责连接池管理和复用
|
||||
type HTTPUpstream interface {
|
||||
// Do 执行 HTTP 请求
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象,由调用方构建
|
||||
// - proxyURL: 代理服务器地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于连接池隔离(隔离策略为 account 或 account_proxy 时生效)
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Response: HTTP 响应,调用方必须关闭 Body
|
||||
// - error: 请求错误(网络错误、超时等)
|
||||
//
|
||||
// 注意:
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - 响应体可能已被包装以跟踪请求生命周期
|
||||
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
||||
}
|
||||
271
backend/internal/service/identity_service.go
Normal file
271
backend/internal/service/identity_service.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 预编译正则表达式(避免每次调用重新编译)
|
||||
var (
|
||||
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
|
||||
// 匹配 User-Agent 版本号: xxx/x.y.z
|
||||
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
|
||||
)
|
||||
|
||||
// 默认指纹值(当客户端未提供时使用)
|
||||
var defaultFingerprint = Fingerprint{
|
||||
UserAgent: "claude-cli/2.0.62 (external, cli)",
|
||||
StainlessLang: "js",
|
||||
StainlessPackageVersion: "0.52.0",
|
||||
StainlessOS: "Linux",
|
||||
StainlessArch: "x64",
|
||||
StainlessRuntime: "node",
|
||||
StainlessRuntimeVersion: "v22.14.0",
|
||||
}
|
||||
|
||||
// Fingerprint represents account fingerprint data
|
||||
type Fingerprint struct {
|
||||
ClientID string
|
||||
UserAgent string
|
||||
StainlessLang string
|
||||
StainlessPackageVersion string
|
||||
StainlessOS string
|
||||
StainlessArch string
|
||||
StainlessRuntime string
|
||||
StainlessRuntimeVersion string
|
||||
}
|
||||
|
||||
// IdentityCache defines cache operations for identity service
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
}
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
type IdentityService struct {
|
||||
cache IdentityCache
|
||||
}
|
||||
|
||||
// NewIdentityService 创建新的IdentityService
|
||||
func NewIdentityService(cache IdentityCache) *IdentityService {
|
||||
return &IdentityService{cache: cache}
|
||||
}
|
||||
|
||||
// GetOrCreateFingerprint 获取或创建账号的指纹
|
||||
// 如果缓存存在,检测user-agent版本,新版本则更新
|
||||
// 如果缓存不存在,生成随机ClientID并从请求头创建指纹,然后缓存
|
||||
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
|
||||
// 尝试从缓存获取指纹
|
||||
cached, err := s.cache.GetFingerprint(ctx, accountID)
|
||||
if err == nil && cached != nil {
|
||||
// 检查客户端的user-agent是否是更新版本
|
||||
clientUA := headers.Get("User-Agent")
|
||||
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
|
||||
// 更新user-agent
|
||||
cached.UserAgent = clientUA
|
||||
// 保存更新后的指纹
|
||||
_ = s.cache.SetFingerprint(ctx, accountID, cached)
|
||||
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 缓存不存在或解析失败,创建新指纹
|
||||
fp := s.createFingerprintFromHeaders(headers)
|
||||
|
||||
// 生成随机ClientID
|
||||
fp.ClientID = generateClientID()
|
||||
|
||||
// 保存到缓存(永不过期)
|
||||
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
|
||||
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
|
||||
}
|
||||
|
||||
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
|
||||
return fp, nil
|
||||
}
|
||||
|
||||
// createFingerprintFromHeaders 从请求头创建指纹
|
||||
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
|
||||
fp := &Fingerprint{}
|
||||
|
||||
// 获取User-Agent
|
||||
if ua := headers.Get("User-Agent"); ua != "" {
|
||||
fp.UserAgent = ua
|
||||
} else {
|
||||
fp.UserAgent = defaultFingerprint.UserAgent
|
||||
}
|
||||
|
||||
// 获取x-stainless-*头,如果没有则使用默认值
|
||||
fp.StainlessLang = getHeaderOrDefault(headers, "X-Stainless-Lang", defaultFingerprint.StainlessLang)
|
||||
fp.StainlessPackageVersion = getHeaderOrDefault(headers, "X-Stainless-Package-Version", defaultFingerprint.StainlessPackageVersion)
|
||||
fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-OS", defaultFingerprint.StainlessOS)
|
||||
fp.StainlessArch = getHeaderOrDefault(headers, "X-Stainless-Arch", defaultFingerprint.StainlessArch)
|
||||
fp.StainlessRuntime = getHeaderOrDefault(headers, "X-Stainless-Runtime", defaultFingerprint.StainlessRuntime)
|
||||
fp.StainlessRuntimeVersion = getHeaderOrDefault(headers, "X-Stainless-Runtime-Version", defaultFingerprint.StainlessRuntimeVersion)
|
||||
|
||||
return fp
|
||||
}
|
||||
|
||||
// getHeaderOrDefault 获取header值,如果不存在则返回默认值
|
||||
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
|
||||
if v := headers.Get(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ApplyFingerprint 将指纹应用到请求头(覆盖原有的x-stainless-*头)
|
||||
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
|
||||
if fp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置user-agent
|
||||
if fp.UserAgent != "" {
|
||||
req.Header.Set("user-agent", fp.UserAgent)
|
||||
}
|
||||
|
||||
// 设置x-stainless-*头
|
||||
if fp.StainlessLang != "" {
|
||||
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
|
||||
}
|
||||
if fp.StainlessPackageVersion != "" {
|
||||
req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion)
|
||||
}
|
||||
if fp.StainlessOS != "" {
|
||||
req.Header.Set("X-Stainless-OS", fp.StainlessOS)
|
||||
}
|
||||
if fp.StainlessArch != "" {
|
||||
req.Header.Set("X-Stainless-Arch", fp.StainlessArch)
|
||||
}
|
||||
if fp.StainlessRuntime != "" {
|
||||
req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime)
|
||||
}
|
||||
if fp.StainlessRuntimeVersion != "" {
|
||||
req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// RewriteUserID 重写body中的metadata.user_id
|
||||
// 输入格式:user_{clientId}_account__session_{sessionUUID}
|
||||
// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash}
|
||||
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
var reqMap map[string]any
|
||||
if err := json.Unmarshal(body, &reqMap); err != nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// 匹配格式: user_{64位hex}_account__session_{uuid}
|
||||
matches := userIDRegex.FindStringSubmatch(userID)
|
||||
if matches == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
sessionTail := matches[1] // 原始session UUID
|
||||
|
||||
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
|
||||
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
|
||||
newSessionHash := generateUUIDFromSeed(seed)
|
||||
|
||||
// 构建新的user_id
|
||||
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
|
||||
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
|
||||
func generateClientID() string {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// 极罕见的情况,使用时间戳+固定值作为fallback
|
||||
log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err)
|
||||
// 使用SHA256(当前纳秒时间)作为fallback
|
||||
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// generateUUIDFromSeed 从种子生成确定性UUID v4格式字符串
|
||||
func generateUUIDFromSeed(seed string) string {
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
bytes := hash[:16]
|
||||
|
||||
// 设置UUID v4版本和变体位
|
||||
bytes[6] = (bytes[6] & 0x0f) | 0x40
|
||||
bytes[8] = (bytes[8] & 0x3f) | 0x80
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
|
||||
}
|
||||
|
||||
// parseUserAgentVersion 解析user-agent版本号
|
||||
// 例如:claude-cli/2.0.62 -> (2, 0, 62)
|
||||
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
|
||||
// 匹配 xxx/x.y.z 格式
|
||||
matches := userAgentVersionRegex.FindStringSubmatch(ua)
|
||||
if len(matches) != 4 {
|
||||
return 0, 0, 0, false
|
||||
}
|
||||
major, _ = strconv.Atoi(matches[1])
|
||||
minor, _ = strconv.Atoi(matches[2])
|
||||
patch, _ = strconv.Atoi(matches[3])
|
||||
return major, minor, patch, true
|
||||
}
|
||||
|
||||
// isNewerVersion 比较版本号,判断newUA是否比cachedUA更新
|
||||
func isNewerVersion(newUA, cachedUA string) bool {
|
||||
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
|
||||
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
|
||||
|
||||
if !newOk || !cachedOk {
|
||||
return false
|
||||
}
|
||||
|
||||
// 比较版本号
|
||||
if newMajor > cachedMajor {
|
||||
return true
|
||||
}
|
||||
if newMajor < cachedMajor {
|
||||
return false
|
||||
}
|
||||
|
||||
if newMinor > cachedMinor {
|
||||
return true
|
||||
}
|
||||
if newMinor < cachedMinor {
|
||||
return false
|
||||
}
|
||||
|
||||
return newPatch > cachedPatch
|
||||
}
|
||||
301
backend/internal/service/oauth_service.go
Normal file
301
backend/internal/service/oauth_service.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// OpenAIOAuthClient interface for OpenAI OAuth operations
|
||||
type OpenAIOAuthClient interface {
|
||||
ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error)
|
||||
}
|
||||
|
||||
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||
type ClaudeOAuthClient interface {
|
||||
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error)
|
||||
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||
}
|
||||
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
sessionStore *oauth.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
oauthClient ClaudeOAuthClient
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(proxyRepo ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
|
||||
return &OAuthService{
|
||||
sessionStore: oauth.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAuthURLResult contains the authorization URL and session info
|
||||
type GenerateAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
||||
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
}
|
||||
|
||||
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
||||
func (s *OAuthService) GenerateSetupTokenURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := oauth.ScopeInference
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
}
|
||||
|
||||
func (s *OAuthService) generateAuthURLWithScope(ctx context.Context, scope string, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate session ID
|
||||
sessionID, err := oauth.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Store session
|
||||
session := &oauth.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
Scope: scope,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := oauth.BuildAuthorizationURL(state, codeChallenge, scope)
|
||||
|
||||
return &GenerateAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeInput represents the input for code exchange
|
||||
type ExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// TokenInfo represents the token information stored in credentials
|
||||
type TokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
OrgUUID string `json:"org_uuid,omitempty"`
|
||||
AccountUUID string `json:"account_uuid,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInput) (*TokenInfo, error) {
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Determine if this is a setup token (scope is inference only)
|
||||
isSetupToken := session.Scope == oauth.ScopeInference
|
||||
|
||||
// Exchange code for token
|
||||
tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL, isSetupToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Delete session after successful exchange
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// CookieAuthInput represents the input for cookie-based authentication
|
||||
type CookieAuthInput struct {
|
||||
SessionKey string
|
||||
ProxyID *int64
|
||||
Scope string // "full" or "inference"
|
||||
}
|
||||
|
||||
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
|
||||
func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (*TokenInfo, error) {
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Determine scope and if this is a setup token
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
isSetupToken := false
|
||||
if input.Scope == "inference" {
|
||||
scope = oauth.ScopeInference
|
||||
isSetupToken = true
|
||||
}
|
||||
|
||||
// Step 1: Get organization info using sessionKey
|
||||
orgUUID, err := s.getOrganizationUUID(ctx, input.SessionKey, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get organization info: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Generate PKCE values
|
||||
codeVerifier, err := oauth.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
state, err := oauth.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Get authorization code using cookie
|
||||
authCode, err := s.getAuthorizationCode(ctx, input.SessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get authorization code: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Exchange code for token
|
||||
tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL, isSetupToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Ensure org_uuid is set (from step 1 if not from token response)
|
||||
if tokenInfo.OrgUUID == "" && orgUUID != "" {
|
||||
tokenInfo.OrgUUID = orgUUID
|
||||
log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
|
||||
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL)
|
||||
}
|
||||
|
||||
// getAuthorizationCode gets the authorization code using sessionKey
|
||||
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges authorization code for tokens
|
||||
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*TokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL, isSetupToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenInfo := &TokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
Scope: tokenResp.Scope,
|
||||
}
|
||||
|
||||
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
|
||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||
}
|
||||
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
|
||||
tokenInfo.AccountUUID = tokenResp.Account.UUID
|
||||
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OAuth token
|
||||
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an account
|
||||
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
528
backend/internal/service/openai_codex_transform.go
Normal file
528
backend/internal/service/openai_codex_transform.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt"
|
||||
codexCacheTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
//go:embed prompts/codex_cli_instructions.md
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-high": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-max": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-low": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-medium": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-high": "gpt-5.1-codex-max",
|
||||
"gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max",
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
"gpt-5.2-none": "gpt-5.2",
|
||||
"gpt-5.2-low": "gpt-5.2",
|
||||
"gpt-5.2-medium": "gpt-5.2",
|
||||
"gpt-5.2-high": "gpt-5.2",
|
||||
"gpt-5.2-xhigh": "gpt-5.2",
|
||||
"gpt-5.2-codex": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-low": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-medium": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-high": "gpt-5.2-codex",
|
||||
"gpt-5.2-codex-xhigh": "gpt-5.2-codex",
|
||||
"gpt-5.1-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini",
|
||||
"gpt-5.1": "gpt-5.1",
|
||||
"gpt-5.1-none": "gpt-5.1",
|
||||
"gpt-5.1-low": "gpt-5.1",
|
||||
"gpt-5.1-medium": "gpt-5.1",
|
||||
"gpt-5.1-high": "gpt-5.1",
|
||||
"gpt-5.1-chat-latest": "gpt-5.1",
|
||||
"gpt-5-codex": "gpt-5.1-codex",
|
||||
"codex-mini-latest": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini-medium": "gpt-5.1-codex-mini",
|
||||
"gpt-5-codex-mini-high": "gpt-5.1-codex-mini",
|
||||
"gpt-5": "gpt-5.1",
|
||||
"gpt-5-mini": "gpt-5.1",
|
||||
"gpt-5-nano": "gpt-5.1",
|
||||
}
|
||||
|
||||
type codexTransformResult struct {
|
||||
Modified bool
|
||||
NormalizedModel string
|
||||
PromptCacheKey string
|
||||
}
|
||||
|
||||
type opencodeCacheMetadata struct {
|
||||
ETag string `json:"etag"`
|
||||
LastFetch string `json:"lastFetch,omitempty"`
|
||||
LastChecked int64 `json:"lastChecked"`
|
||||
}
|
||||
|
||||
func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult {
|
||||
result := codexTransformResult{}
|
||||
// 工具续链需求会影响存储策略与 input 过滤逻辑。
|
||||
needsToolContinuation := NeedsToolContinuation(reqBody)
|
||||
|
||||
model := ""
|
||||
if v, ok := reqBody["model"].(string); ok {
|
||||
model = v
|
||||
}
|
||||
normalizedModel := normalizeCodexModel(model)
|
||||
if normalizedModel != "" {
|
||||
if model != normalizedModel {
|
||||
reqBody["model"] = normalizedModel
|
||||
result.Modified = true
|
||||
}
|
||||
result.NormalizedModel = normalizedModel
|
||||
}
|
||||
|
||||
// OAuth 走 ChatGPT internal API 时,store 必须为 false;显式 true 也会强制覆盖。
|
||||
// 避免上游返回 "Store must be set to false"。
|
||||
if v, ok := reqBody["store"].(bool); !ok || v {
|
||||
reqBody["store"] = false
|
||||
result.Modified = true
|
||||
}
|
||||
if v, ok := reqBody["stream"].(bool); !ok || !v {
|
||||
reqBody["stream"] = true
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
delete(reqBody, "max_output_tokens")
|
||||
result.Modified = true
|
||||
}
|
||||
if _, ok := reqBody["max_completion_tokens"]; ok {
|
||||
delete(reqBody, "max_completion_tokens")
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if normalizeCodexTools(reqBody) {
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
result.PromptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
instructions := strings.TrimSpace(getOpenCodeCodexHeader())
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
existingInstructions = strings.TrimSpace(existingInstructions)
|
||||
|
||||
if instructions != "" {
|
||||
if existingInstructions != instructions {
|
||||
reqBody["instructions"] = instructions
|
||||
result.Modified = true
|
||||
}
|
||||
} else if existingInstructions == "" {
|
||||
// 未获取到 opencode 指令时,回退使用 Codex CLI 指令。
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions != "" {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
result.Modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。
|
||||
if input, ok := reqBody["input"].([]any); ok {
|
||||
input = filterCodexInput(input, needsToolContinuation)
|
||||
reqBody["input"] = input
|
||||
result.Modified = true
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeCodexModel(model string) string {
|
||||
if model == "" {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
|
||||
modelID := model
|
||||
if strings.Contains(modelID, "/") {
|
||||
parts := strings.Split(modelID, "/")
|
||||
modelID = parts[len(parts)-1]
|
||||
}
|
||||
|
||||
if mapped := getNormalizedCodexModel(modelID); mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(modelID)
|
||||
|
||||
if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") {
|
||||
return "gpt-5.2-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||
return "gpt-5.1-codex-max"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") {
|
||||
return "gpt-5.1-codex-mini"
|
||||
}
|
||||
if strings.Contains(normalized, "codex-mini-latest") ||
|
||||
strings.Contains(normalized, "gpt-5-codex-mini") ||
|
||||
strings.Contains(normalized, "gpt 5 codex mini") {
|
||||
return "codex-mini-latest"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") {
|
||||
return "gpt-5.1-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
if strings.Contains(normalized, "codex") {
|
||||
return "gpt-5.1-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") {
|
||||
return "gpt-5.1"
|
||||
}
|
||||
|
||||
return "gpt-5.1"
|
||||
}
|
||||
|
||||
func getNormalizedCodexModel(modelID string) string {
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
if mapped, ok := codexModelMap[modelID]; ok {
|
||||
return mapped
|
||||
}
|
||||
lower := strings.ToLower(modelID)
|
||||
for key, value := range codexModelMap {
|
||||
if strings.ToLower(key) == lower {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string {
|
||||
cacheDir := codexCachePath("")
|
||||
if cacheDir == "" {
|
||||
return ""
|
||||
}
|
||||
cacheFile := filepath.Join(cacheDir, cacheFileName)
|
||||
metaFile := filepath.Join(cacheDir, metaFileName)
|
||||
|
||||
var cachedContent string
|
||||
if content, ok := readFile(cacheFile); ok {
|
||||
cachedContent = content
|
||||
}
|
||||
|
||||
var meta opencodeCacheMetadata
|
||||
if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" {
|
||||
if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL {
|
||||
return cachedContent
|
||||
}
|
||||
}
|
||||
|
||||
content, etag, status, err := fetchWithETag(url, meta.ETag)
|
||||
if err == nil && status == http.StatusNotModified && cachedContent != "" {
|
||||
return cachedContent
|
||||
}
|
||||
if err == nil && status >= 200 && status < 300 && content != "" {
|
||||
_ = writeFile(cacheFile, content)
|
||||
meta = opencodeCacheMetadata{
|
||||
ETag: etag,
|
||||
LastFetch: time.Now().UTC().Format(time.RFC3339),
|
||||
LastChecked: time.Now().UnixMilli(),
|
||||
}
|
||||
_ = writeJSON(metaFile, meta)
|
||||
return content
|
||||
}
|
||||
|
||||
return cachedContent
|
||||
}
|
||||
|
||||
func getOpenCodeCodexHeader() string {
|
||||
// 优先从 opencode 仓库缓存获取指令。
|
||||
opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json")
|
||||
|
||||
// 若 opencode 指令可用,直接返回。
|
||||
if opencodeInstructions != "" {
|
||||
return opencodeInstructions
|
||||
}
|
||||
|
||||
// 否则回退使用本地 Codex CLI 指令。
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
func getCodexCLIInstructions() string {
|
||||
return codexCLIInstructions
|
||||
}
|
||||
|
||||
func GetOpenCodeInstructions() string {
|
||||
return getOpenCodeCodexHeader()
|
||||
}
|
||||
|
||||
// GetCodexCLIInstructions 返回内置的 Codex CLI 指令内容。
|
||||
func GetCodexCLIInstructions() string {
|
||||
return getCodexCLIInstructions()
|
||||
}
|
||||
|
||||
// ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。
|
||||
func ReplaceWithCodexInstructions(reqBody map[string]any) bool {
|
||||
codexInstructions := strings.TrimSpace(getCodexCLIInstructions())
|
||||
if codexInstructions == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
existingInstructions, _ := reqBody["instructions"].(string)
|
||||
if strings.TrimSpace(existingInstructions) != codexInstructions {
|
||||
reqBody["instructions"] = codexInstructions
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsInstructionError 判断错误信息是否与指令格式/系统提示相关。
|
||||
func IsInstructionError(errorMessage string) bool {
|
||||
if errorMessage == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(errorMessage)
|
||||
instructionKeywords := []string{
|
||||
"instruction",
|
||||
"instructions",
|
||||
"system prompt",
|
||||
"system message",
|
||||
"invalid prompt",
|
||||
"prompt format",
|
||||
}
|
||||
|
||||
for _, keyword := range instructionKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// filterCodexInput 按需过滤 item_reference 与 id。
|
||||
// preserveReferences 为 true 时保持引用与 id,以满足续链请求对上下文的依赖。
|
||||
func filterCodexInput(input []any, preserveReferences bool) []any {
|
||||
filtered := make([]any, 0, len(input))
|
||||
for _, item := range input {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
filtered = append(filtered, item)
|
||||
continue
|
||||
}
|
||||
typ, _ := m["type"].(string)
|
||||
if typ == "item_reference" {
|
||||
if !preserveReferences {
|
||||
continue
|
||||
}
|
||||
newItem := make(map[string]any, len(m))
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
filtered = append(filtered, newItem)
|
||||
continue
|
||||
}
|
||||
|
||||
newItem := m
|
||||
copied := false
|
||||
// 仅在需要修改字段时创建副本,避免直接改写原始输入。
|
||||
ensureCopy := func() {
|
||||
if copied {
|
||||
return
|
||||
}
|
||||
newItem = make(map[string]any, len(m))
|
||||
for key, value := range m {
|
||||
newItem[key] = value
|
||||
}
|
||||
copied = true
|
||||
}
|
||||
|
||||
if isCodexToolCallItemType(typ) {
|
||||
if callID, ok := m["call_id"].(string); !ok || strings.TrimSpace(callID) == "" {
|
||||
if id, ok := m["id"].(string); ok && strings.TrimSpace(id) != "" {
|
||||
ensureCopy()
|
||||
newItem["call_id"] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !preserveReferences {
|
||||
ensureCopy()
|
||||
delete(newItem, "id")
|
||||
if !isCodexToolCallItemType(typ) {
|
||||
delete(newItem, "call_id")
|
||||
}
|
||||
}
|
||||
|
||||
filtered = append(filtered, newItem)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func isCodexToolCallItemType(typ string) bool {
|
||||
if typ == "" {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(typ, "_call") || strings.HasSuffix(typ, "_call_output")
|
||||
}
|
||||
|
||||
func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
rawTools, ok := reqBody["tools"]
|
||||
if !ok || rawTools == nil {
|
||||
return false
|
||||
}
|
||||
tools, ok := rawTools.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
modified := false
|
||||
for idx, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
toolType, _ := toolMap["type"].(string)
|
||||
if strings.TrimSpace(toolType) != "function" {
|
||||
continue
|
||||
}
|
||||
|
||||
function, ok := toolMap["function"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := toolMap["name"]; !ok {
|
||||
if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
toolMap["name"] = name
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["description"]; !ok {
|
||||
if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" {
|
||||
toolMap["description"] = desc
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["parameters"]; !ok {
|
||||
if params, ok := function["parameters"]; ok {
|
||||
toolMap["parameters"] = params
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
if _, ok := toolMap["strict"]; !ok {
|
||||
if strict, ok := function["strict"]; ok {
|
||||
toolMap["strict"] = strict
|
||||
modified = true
|
||||
}
|
||||
}
|
||||
|
||||
tools[idx] = toolMap
|
||||
}
|
||||
|
||||
if modified {
|
||||
reqBody["tools"] = tools
|
||||
}
|
||||
|
||||
return modified
|
||||
}
|
||||
|
||||
func codexCachePath(filename string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
cacheDir := filepath.Join(home, ".opencode", "cache")
|
||||
if filename == "" {
|
||||
return cacheDir
|
||||
}
|
||||
return filepath.Join(cacheDir, filename)
|
||||
}
|
||||
|
||||
func readFile(path string) (string, bool) {
|
||||
if path == "" {
|
||||
return "", false
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
return string(data), true
|
||||
}
|
||||
|
||||
func writeFile(path, content string) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty cache path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, []byte(content), 0o644)
|
||||
}
|
||||
|
||||
func loadJSON(path string, target any) bool {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(data, target); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func writeJSON(path string, value any) error {
|
||||
if path == "" {
|
||||
return fmt.Errorf("empty json path")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o644)
|
||||
}
|
||||
|
||||
func fetchWithETag(url, etag string) (string, string, int, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
req.Header.Set("User-Agent", "sub2api-codex")
|
||||
if etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", "", resp.StatusCode, err
|
||||
}
|
||||
return string(body), resp.Header.Get("etag"), resp.StatusCode, nil
|
||||
}
|
||||
167
backend/internal/service/openai_codex_transform_test.go
Normal file
167
backend/internal/service/openai_codex_transform_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
|
||||
// 续链场景:保留 item_reference 与 id,但不再强制 store=true。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.2",
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "ref1", "text": "x"},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1", "output": "ok", "id": "o1"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
// 未显式设置 store=true,默认为 false。
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 2)
|
||||
|
||||
// 校验 input[0] 为 map,避免断言失败导致测试中断。
|
||||
first, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "item_reference", first["type"])
|
||||
require.Equal(t, "ref1", first["id"])
|
||||
|
||||
// 校验 input[1] 为 map,确保后续字段断言安全。
|
||||
second, ok := input[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "o1", second["id"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) {
|
||||
// 续链场景:显式 store=false 不再强制为 true,保持 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"store": false,
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) {
|
||||
// 显式 store=true 也会强制为 false。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"store": true,
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
"tool_choice": "auto",
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) {
|
||||
// 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{
|
||||
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
store, ok := reqBody["store"].(bool)
|
||||
require.True(t, ok)
|
||||
require.False(t, store)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 1)
|
||||
// 校验 input[0] 为 map,避免类型不匹配触发 errcheck。
|
||||
item, ok := input[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
_, hasID := item["id"]
|
||||
require.False(t, hasID)
|
||||
}
|
||||
|
||||
func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||
input := []any{
|
||||
map[string]any{"type": "item_reference", "id": "ref1"},
|
||||
map[string]any{"type": "text", "id": "t1", "text": "hi"},
|
||||
}
|
||||
|
||||
filtered := filterCodexInput(input, false)
|
||||
require.Len(t, filtered, 1)
|
||||
// 校验 filtered[0] 为 map,确保字段检查可靠。
|
||||
item, ok := filtered[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "text", item["type"])
|
||||
_, hasID := item["id"]
|
||||
require.False(t, hasID)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"input": []any{},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
input, ok := reqBody["input"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, input, 0)
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
// 使用临时 HOME 避免触发网络拉取 header。
|
||||
tempDir := t.TempDir()
|
||||
t.Setenv("HOME", tempDir)
|
||||
|
||||
cacheDir := filepath.Join(tempDir, ".opencode", "cache")
|
||||
require.NoError(t, os.MkdirAll(cacheDir, 0o755))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644))
|
||||
|
||||
meta := map[string]any{
|
||||
"etag": "",
|
||||
"lastFetch": time.Now().UTC().Format(time.RFC3339),
|
||||
"lastChecked": time.Now().UnixMilli(),
|
||||
}
|
||||
data, err := json.Marshal(meta)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644))
|
||||
}
|
||||
1736
backend/internal/service/openai_gateway_service.go
Normal file
1736
backend/internal/service/openai_gateway_service.go
Normal file
File diff suppressed because it is too large
Load Diff
410
backend/internal/service/openai_gateway_service_test.go
Normal file
410
backend/internal/service/openai_gateway_service_test.go
Normal file
@@ -0,0 +1,410 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type stubOpenAIAccountRepo struct {
|
||||
AccountRepository
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
groupID := int64(1)
|
||||
|
||||
rateLimited := Account{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
RateLimitResetAt: &resetAt,
|
||||
}
|
||||
available := Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection with account")
|
||||
}
|
||||
if selection.Account.ID != available.ID {
|
||||
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
groupID := int64(1)
|
||||
|
||||
rateLimited := Account{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
RateLimitResetAt: &resetAt,
|
||||
}
|
||||
available := Account{
|
||||
ID: 2,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 1,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}},
|
||||
// concurrencyService is nil, forcing the non-load-batch selection path.
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection with account")
|
||||
}
|
||||
if selection.Account.ID != available.ID {
|
||||
t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID)
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 1,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model")
|
||||
_ = pw.Close()
|
||||
_ = pr.Close()
|
||||
|
||||
if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") {
|
||||
t.Fatalf("expected stream timeout error, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "stream_timeout") {
|
||||
t.Fatalf("expected stream_timeout SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTooLong(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: 64 * 1024,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
// 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong
|
||||
payload := "data: " + strings.Repeat("a", 128*1024) + "\n"
|
||||
_, _ = pw.Write([]byte(payload))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
|
||||
if !errors.Is(err, bufio.ErrTooLong) {
|
||||
t.Fatalf("expected ErrTooLong, got %v", err)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "response_too_large") {
|
||||
t.Fatalf("expected response_too_large SSE error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}},
|
||||
}
|
||||
|
||||
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") {
|
||||
t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINonStreamingContentTypeDefault(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
_, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model")
|
||||
if err != nil {
|
||||
t.Fatalf("handleNonStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") {
|
||||
t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingHeadersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
ResponseHeaders: config.ResponseHeaderConfig{Enabled: false},
|
||||
},
|
||||
Gateway: config.GatewayConfig{
|
||||
StreamDataIntervalTimeout: 0,
|
||||
StreamKeepaliveInterval: 0,
|
||||
MaxLineSize: defaultMaxLineSize,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: pr,
|
||||
Header: http.Header{
|
||||
"Cache-Control": []string{"upstream"},
|
||||
"X-Request-Id": []string{"req-123"},
|
||||
"Content-Type": []string{"application/custom"},
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() { _ = pw.Close() }()
|
||||
_, _ = pw.Write([]byte("data: {}\n\n"))
|
||||
}()
|
||||
|
||||
_, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
|
||||
_ = pr.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("handleStreamingResponse error: %v", err)
|
||||
}
|
||||
|
||||
if rec.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control"))
|
||||
}
|
||||
if rec.Header().Get("Content-Type") != "text/event-stream" {
|
||||
t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type"))
|
||||
}
|
||||
if rec.Header().Get("X-Request-Id") != "req-123" {
|
||||
t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: map[string]any{"base_url": "://invalid-url"},
|
||||
}
|
||||
|
||||
_, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid base_url when allowlist disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{Enabled: false},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil {
|
||||
t.Fatalf("expected http to be rejected when allow_insecure_http is false")
|
||||
}
|
||||
normalized, err := svc.validateUpstreamBaseURL("https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err)
|
||||
}
|
||||
if normalized != "https://example.com" {
|
||||
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: false,
|
||||
AllowInsecureHTTP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err)
|
||||
}
|
||||
if normalized != "http://not-https.example.com" {
|
||||
t.Fatalf("expected raw url passthrough, got %q", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
Enabled: true,
|
||||
UpstreamHosts: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{cfg: cfg}
|
||||
|
||||
if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil {
|
||||
t.Fatalf("expected allowlisted host to pass, got %v", err)
|
||||
}
|
||||
if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil {
|
||||
t.Fatalf("expected non-allowlisted host to fail")
|
||||
}
|
||||
}
|
||||
255
backend/internal/service/openai_oauth_service.go
Normal file
255
backend/internal/service/openai_oauth_service.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
// OpenAIOAuthService handles OpenAI OAuth authentication flows
|
||||
type OpenAIOAuthService struct {
|
||||
sessionStore *openai.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
oauthClient OpenAIOAuthClient
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthService creates a new OpenAI OAuth service
|
||||
func NewOpenAIOAuthService(proxyRepo ProxyRepository, oauthClient OpenAIOAuthClient) *OpenAIOAuthService {
|
||||
return &OpenAIOAuthService{
|
||||
sessionStore: openai.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
oauthClient: oauthClient,
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAIAuthURLResult contains the authorization URL and session info
|
||||
type OpenAIAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL generates an OpenAI OAuth authorization URL
|
||||
func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) {
|
||||
// Generate PKCE values
|
||||
state, err := openai.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := openai.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||
}
|
||||
|
||||
codeChallenge := openai.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate session ID
|
||||
sessionID, err := openai.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ID: %w", err)
|
||||
}
|
||||
|
||||
// Get proxy URL if specified
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use default redirect URI if not specified
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
|
||||
// Store session
|
||||
session := &openai.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
// Build authorization URL
|
||||
authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI)
|
||||
|
||||
return &OpenAIAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OpenAIExchangeCodeInput represents the input for code exchange
|
||||
type OpenAIExchangeCodeInput struct {
|
||||
SessionID string
|
||||
Code string
|
||||
RedirectURI string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// OpenAITokenInfo represents the token information for OpenAI
|
||||
type OpenAITokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"`
|
||||
ChatGPTUserID string `json:"chatgpt_user_id,omitempty"`
|
||||
OrganizationID string `json:"organization_id,omitempty"`
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges authorization code for tokens
|
||||
func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExchangeCodeInput) (*OpenAITokenInfo, error) {
|
||||
// Get session
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
|
||||
// Get proxy URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// Use redirect URI from session or input
|
||||
redirectURI := session.RedirectURI
|
||||
if input.RedirectURI != "" {
|
||||
redirectURI = input.RedirectURI
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
// Delete session after successful exchange
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an OpenAI OAuth token
|
||||
func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) {
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse ID token to get user info
|
||||
var userInfo *openai.UserInfo
|
||||
if tokenResp.IDToken != "" {
|
||||
claims, err := openai.ParseIDToken(tokenResp.IDToken)
|
||||
if err == nil {
|
||||
userInfo = claims.GetUserInfo()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo := &OpenAITokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
IDToken: tokenResp.IDToken,
|
||||
ExpiresIn: int64(tokenResp.ExpiresIn),
|
||||
ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn),
|
||||
}
|
||||
|
||||
if userInfo != nil {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
tokenInfo.ChatGPTAccountID = userInfo.ChatGPTAccountID
|
||||
tokenInfo.ChatGPTUserID = userInfo.ChatGPTUserID
|
||||
tokenInfo.OrganizationID = userInfo.OrganizationID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
}
|
||||
|
||||
refreshToken := account.GetOpenAIRefreshToken()
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials builds credentials map from token info
|
||||
func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) map[string]any {
|
||||
expiresAt := time.Unix(tokenInfo.ExpiresAt, 0).Format(time.RFC3339)
|
||||
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"refresh_token": tokenInfo.RefreshToken,
|
||||
"expires_at": expiresAt,
|
||||
}
|
||||
|
||||
if tokenInfo.IDToken != "" {
|
||||
creds["id_token"] = tokenInfo.IDToken
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ChatGPTAccountID != "" {
|
||||
creds["chatgpt_account_id"] = tokenInfo.ChatGPTAccountID
|
||||
}
|
||||
if tokenInfo.ChatGPTUserID != "" {
|
||||
creds["chatgpt_user_id"] = tokenInfo.ChatGPTUserID
|
||||
}
|
||||
if tokenInfo.OrganizationID != "" {
|
||||
creds["organization_id"] = tokenInfo.OrganizationID
|
||||
}
|
||||
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop stops the session store cleanup goroutine
|
||||
func (s *OpenAIOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
213
backend/internal/service/openai_tool_continuation.go
Normal file
213
backend/internal/service/openai_tool_continuation.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package service
|
||||
|
||||
import "strings"
|
||||
|
||||
// NeedsToolContinuation 判定请求是否需要工具调用续链处理。
|
||||
// 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、
|
||||
// 或显式声明 tools/tool_choice。
|
||||
func NeedsToolContinuation(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
if hasNonEmptyString(reqBody["previous_response_id"]) {
|
||||
return true
|
||||
}
|
||||
if hasToolsSignal(reqBody) {
|
||||
return true
|
||||
}
|
||||
if hasToolChoiceSignal(reqBody) {
|
||||
return true
|
||||
}
|
||||
if inputHasType(reqBody, "function_call_output") {
|
||||
return true
|
||||
}
|
||||
if inputHasType(reqBody, "item_reference") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。
|
||||
func HasFunctionCallOutput(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
return inputHasType(reqBody, "function_call_output")
|
||||
}
|
||||
|
||||
// HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call,
|
||||
// 用于判断 function_call_output 是否具备可关联的上下文。
|
||||
func HasToolCallContext(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "tool_call" && itemType != "function_call" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。
|
||||
// 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。
|
||||
func FunctionCallOutputCallIDs(reqBody map[string]any) []string {
|
||||
if reqBody == nil {
|
||||
return nil
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ids := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" {
|
||||
ids[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(ids))
|
||||
for id := range ids {
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。
|
||||
func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool {
|
||||
if reqBody == nil {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "function_call_output" {
|
||||
continue
|
||||
}
|
||||
callID, _ := itemMap["call_id"].(string)
|
||||
if strings.TrimSpace(callID) == "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。
|
||||
// 用于仅依赖引用项完成续链场景的校验。
|
||||
func HasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool {
|
||||
if reqBody == nil || len(callIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
referenceIDs := make(map[string]struct{})
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType != "item_reference" {
|
||||
continue
|
||||
}
|
||||
idValue, _ := itemMap["id"].(string)
|
||||
idValue = strings.TrimSpace(idValue)
|
||||
if idValue == "" {
|
||||
continue
|
||||
}
|
||||
referenceIDs[idValue] = struct{}{}
|
||||
}
|
||||
if len(referenceIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, callID := range callIDs {
|
||||
if _, ok := referenceIDs[callID]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// inputHasType 判断 input 中是否存在指定类型的 item。
|
||||
func inputHasType(reqBody map[string]any, want string) bool {
|
||||
input, ok := reqBody["input"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, item := range input {
|
||||
itemMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
itemType, _ := itemMap["type"].(string)
|
||||
if itemType == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasNonEmptyString 判断字段是否为非空字符串。
|
||||
func hasNonEmptyString(value any) bool {
|
||||
stringValue, ok := value.(string)
|
||||
return ok && strings.TrimSpace(stringValue) != ""
|
||||
}
|
||||
|
||||
// hasToolsSignal 判断 tools 字段是否显式声明(存在且不为空)。
|
||||
func hasToolsSignal(reqBody map[string]any) bool {
|
||||
raw, exists := reqBody["tools"]
|
||||
if !exists || raw == nil {
|
||||
return false
|
||||
}
|
||||
if tools, ok := raw.([]any); ok {
|
||||
return len(tools) > 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasToolChoiceSignal 判断 tool_choice 是否显式声明(非空或非 nil)。
|
||||
func hasToolChoiceSignal(reqBody map[string]any) bool {
|
||||
raw, exists := reqBody["tool_choice"]
|
||||
if !exists || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(value) != ""
|
||||
case map[string]any:
|
||||
return len(value) > 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
98
backend/internal/service/openai_tool_continuation_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNeedsToolContinuationSignals(t *testing.T) {
|
||||
// 覆盖所有触发续链的信号来源,确保判定逻辑完整。
|
||||
cases := []struct {
|
||||
name string
|
||||
body map[string]any
|
||||
want bool
|
||||
}{
|
||||
{name: "nil", body: nil, want: false},
|
||||
{name: "previous_response_id", body: map[string]any{"previous_response_id": "resp_1"}, want: true},
|
||||
{name: "previous_response_id_blank", body: map[string]any{"previous_response_id": " "}, want: false},
|
||||
{name: "function_call_output", body: map[string]any{"input": []any{map[string]any{"type": "function_call_output"}}}, want: true},
|
||||
{name: "item_reference", body: map[string]any{"input": []any{map[string]any{"type": "item_reference"}}}, want: true},
|
||||
{name: "tools", body: map[string]any{"tools": []any{map[string]any{"type": "function"}}}, want: true},
|
||||
{name: "tools_empty", body: map[string]any{"tools": []any{}}, want: false},
|
||||
{name: "tools_invalid", body: map[string]any{"tools": "bad"}, want: false},
|
||||
{name: "tool_choice", body: map[string]any{"tool_choice": "auto"}, want: true},
|
||||
{name: "tool_choice_object", body: map[string]any{"tool_choice": map[string]any{"type": "function"}}, want: true},
|
||||
{name: "tool_choice_empty_object", body: map[string]any{"tool_choice": map[string]any{}}, want: false},
|
||||
{name: "none", body: map[string]any{"input": []any{map[string]any{"type": "text", "text": "hi"}}}, want: false},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, NeedsToolContinuation(tt.body))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasFunctionCallOutput(t *testing.T) {
|
||||
// 仅当 input 中存在 function_call_output 才视为续链输出。
|
||||
require.False(t, HasFunctionCallOutput(nil))
|
||||
require.True(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||
}))
|
||||
require.False(t, HasFunctionCallOutput(map[string]any{
|
||||
"input": "text",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHasToolCallContext(t *testing.T) {
|
||||
// tool_call/function_call 必须包含 call_id,才能作为可关联上下文。
|
||||
require.False(t, HasToolCallContext(nil))
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_call", "call_id": "call_1"}},
|
||||
}))
|
||||
require.True(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call", "call_id": "call_2"}},
|
||||
}))
|
||||
require.False(t, HasToolCallContext(map[string]any{
|
||||
"input": []any{map[string]any{"type": "tool_call"}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestFunctionCallOutputCallIDs(t *testing.T) {
|
||||
// 仅提取非空 call_id,去重后返回。
|
||||
require.Empty(t, FunctionCallOutputCallIDs(nil))
|
||||
callIDs := FunctionCallOutputCallIDs(map[string]any{
|
||||
"input": []any{
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
map[string]any{"type": "function_call_output", "call_id": ""},
|
||||
map[string]any{"type": "function_call_output", "call_id": "call_1"},
|
||||
},
|
||||
})
|
||||
require.ElementsMatch(t, []string{"call_1"}, callIDs)
|
||||
}
|
||||
|
||||
func TestHasFunctionCallOutputMissingCallID(t *testing.T) {
|
||||
require.False(t, HasFunctionCallOutputMissingCallID(nil))
|
||||
require.True(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output"}},
|
||||
}))
|
||||
require.False(t, HasFunctionCallOutputMissingCallID(map[string]any{
|
||||
"input": []any{map[string]any{"type": "function_call_output", "call_id": "call_1"}},
|
||||
}))
|
||||
}
|
||||
|
||||
func TestHasItemReferenceForCallIDs(t *testing.T) {
|
||||
// item_reference 需要覆盖所有 call_id 才视为可关联上下文。
|
||||
require.False(t, HasItemReferenceForCallIDs(nil, []string{"call_1"}))
|
||||
require.False(t, HasItemReferenceForCallIDs(map[string]any{}, []string{"call_1"}))
|
||||
req := map[string]any{
|
||||
"input": []any{
|
||||
map[string]any{"type": "item_reference", "id": "call_1"},
|
||||
map[string]any{"type": "item_reference", "id": "call_2"},
|
||||
},
|
||||
}
|
||||
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1"}))
|
||||
require.True(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_2"}))
|
||||
require.False(t, HasItemReferenceForCallIDs(req, []string{"call_1", "call_3"}))
|
||||
}
|
||||
194
backend/internal/service/ops_account_availability.go
Normal file
194
backend/internal/service/ops_account_availability.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetAccountAvailabilityStats returns current account availability stats.
|
||||
//
|
||||
// Query-level filtering is intentionally limited to platform/group to match the dashboard scope.
|
||||
func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFilter string, groupIDFilter *int64) (
|
||||
map[string]*PlatformAvailability,
|
||||
map[int64]*GroupAvailability,
|
||||
map[int64]*AccountAvailability,
|
||||
*time.Time,
|
||||
error,
|
||||
) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp != nil && grp.ID == *groupIDFilter {
|
||||
filtered = append(filtered, acc)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
accounts = filtered
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
collectedAt := now
|
||||
|
||||
platform := make(map[string]*PlatformAvailability)
|
||||
group := make(map[int64]*GroupAvailability)
|
||||
account := make(map[int64]*AccountAvailability)
|
||||
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
isTempUnsched := false
|
||||
if acc.TempUnschedulableUntil != nil && now.Before(*acc.TempUnschedulableUntil) {
|
||||
isTempUnsched = true
|
||||
}
|
||||
|
||||
isRateLimited := acc.RateLimitResetAt != nil && now.Before(*acc.RateLimitResetAt)
|
||||
isOverloaded := acc.OverloadUntil != nil && now.Before(*acc.OverloadUntil)
|
||||
hasError := acc.Status == StatusError
|
||||
|
||||
// Normalize exclusive status flags so the UI doesn't show conflicting badges.
|
||||
if hasError {
|
||||
isRateLimited = false
|
||||
isOverloaded = false
|
||||
}
|
||||
|
||||
isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched
|
||||
|
||||
if acc.Platform != "" {
|
||||
if _, ok := platform[acc.Platform]; !ok {
|
||||
platform[acc.Platform] = &PlatformAvailability{
|
||||
Platform: acc.Platform,
|
||||
}
|
||||
}
|
||||
p := platform[acc.Platform]
|
||||
p.TotalAccounts++
|
||||
if isAvailable {
|
||||
p.AvailableCount++
|
||||
}
|
||||
if isRateLimited {
|
||||
p.RateLimitCount++
|
||||
}
|
||||
if hasError {
|
||||
p.ErrorCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := group[grp.ID]; !ok {
|
||||
group[grp.ID] = &GroupAvailability{
|
||||
GroupID: grp.ID,
|
||||
GroupName: grp.Name,
|
||||
Platform: grp.Platform,
|
||||
}
|
||||
}
|
||||
g := group[grp.ID]
|
||||
g.TotalAccounts++
|
||||
if isAvailable {
|
||||
g.AvailableCount++
|
||||
}
|
||||
if isRateLimited {
|
||||
g.RateLimitCount++
|
||||
}
|
||||
if hasError {
|
||||
g.ErrorCount++
|
||||
}
|
||||
}
|
||||
|
||||
displayGroupID := int64(0)
|
||||
displayGroupName := ""
|
||||
if len(acc.Groups) > 0 && acc.Groups[0] != nil {
|
||||
displayGroupID = acc.Groups[0].ID
|
||||
displayGroupName = acc.Groups[0].Name
|
||||
}
|
||||
|
||||
item := &AccountAvailability{
|
||||
AccountID: acc.ID,
|
||||
AccountName: acc.Name,
|
||||
Platform: acc.Platform,
|
||||
GroupID: displayGroupID,
|
||||
GroupName: displayGroupName,
|
||||
Status: acc.Status,
|
||||
|
||||
IsAvailable: isAvailable,
|
||||
IsRateLimited: isRateLimited,
|
||||
IsOverloaded: isOverloaded,
|
||||
HasError: hasError,
|
||||
|
||||
ErrorMessage: acc.ErrorMessage,
|
||||
}
|
||||
|
||||
if isRateLimited && acc.RateLimitResetAt != nil {
|
||||
item.RateLimitResetAt = acc.RateLimitResetAt
|
||||
remainingSec := int64(time.Until(*acc.RateLimitResetAt).Seconds())
|
||||
if remainingSec > 0 {
|
||||
item.RateLimitRemainingSec = &remainingSec
|
||||
}
|
||||
}
|
||||
if isOverloaded && acc.OverloadUntil != nil {
|
||||
item.OverloadUntil = acc.OverloadUntil
|
||||
remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds())
|
||||
if remainingSec > 0 {
|
||||
item.OverloadRemainingSec = &remainingSec
|
||||
}
|
||||
}
|
||||
if isTempUnsched && acc.TempUnschedulableUntil != nil {
|
||||
item.TempUnschedulableUntil = acc.TempUnschedulableUntil
|
||||
}
|
||||
|
||||
account[acc.ID] = item
|
||||
}
|
||||
|
||||
return platform, group, account, &collectedAt, nil
|
||||
}
|
||||
|
||||
type OpsAccountAvailability struct {
|
||||
Group *GroupAvailability
|
||||
Accounts map[int64]*AccountAvailability
|
||||
CollectedAt *time.Time
|
||||
}
|
||||
|
||||
func (s *OpsService) GetAccountAvailability(ctx context.Context, platformFilter string, groupIDFilter *int64) (*OpsAccountAvailability, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("ops service is nil")
|
||||
}
|
||||
|
||||
if s.getAccountAvailability != nil {
|
||||
return s.getAccountAvailability(ctx, platformFilter, groupIDFilter)
|
||||
}
|
||||
|
||||
_, groupStats, accountStats, collectedAt, err := s.GetAccountAvailabilityStats(ctx, platformFilter, groupIDFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var group *GroupAvailability
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
group = groupStats[*groupIDFilter]
|
||||
}
|
||||
|
||||
if accountStats == nil {
|
||||
accountStats = map[int64]*AccountAvailability{}
|
||||
}
|
||||
|
||||
return &OpsAccountAvailability{
|
||||
Group: group,
|
||||
Accounts: accountStats,
|
||||
CollectedAt: collectedAt,
|
||||
}, nil
|
||||
}
|
||||
46
backend/internal/service/ops_advisory_lock.go
Normal file
46
backend/internal/service/ops_advisory_lock.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"hash/fnv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func hashAdvisoryLockID(key string) int64 {
|
||||
h := fnv.New64a()
|
||||
_, _ = h.Write([]byte(key))
|
||||
return int64(h.Sum64())
|
||||
}
|
||||
|
||||
func tryAcquireDBAdvisoryLock(ctx context.Context, db *sql.DB, lockID int64) (func(), bool) {
|
||||
if db == nil {
|
||||
return nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
acquired := false
|
||||
if err := conn.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", lockID).Scan(&acquired); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, false
|
||||
}
|
||||
if !acquired {
|
||||
_ = conn.Close()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
release := func() {
|
||||
unlockCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, _ = conn.ExecContext(unlockCtx, "SELECT pg_advisory_unlock($1)", lockID)
|
||||
_ = conn.Close()
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
443
backend/internal/service/ops_aggregation_service.go
Normal file
443
backend/internal/service/ops_aggregation_service.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAggHourlyJobName = "ops_preaggregation_hourly"
|
||||
opsAggDailyJobName = "ops_preaggregation_daily"
|
||||
|
||||
opsAggHourlyInterval = 10 * time.Minute
|
||||
opsAggDailyInterval = 1 * time.Hour
|
||||
|
||||
// Keep in sync with ops retention target (vNext default 30d).
|
||||
opsAggBackfillWindow = 30 * 24 * time.Hour
|
||||
|
||||
// Recompute overlap to absorb late-arriving rows near boundaries.
|
||||
opsAggHourlyOverlap = 2 * time.Hour
|
||||
opsAggDailyOverlap = 48 * time.Hour
|
||||
|
||||
opsAggHourlyChunk = 24 * time.Hour
|
||||
opsAggDailyChunk = 7 * 24 * time.Hour
|
||||
|
||||
// Delay around boundaries (e.g. 10:00..10:05) to avoid aggregating buckets
|
||||
// that may still receive late inserts.
|
||||
opsAggSafeDelay = 5 * time.Minute
|
||||
|
||||
opsAggMaxQueryTimeout = 3 * time.Second
|
||||
opsAggHourlyTimeout = 5 * time.Minute
|
||||
opsAggDailyTimeout = 2 * time.Minute
|
||||
|
||||
opsAggHourlyLeaderLockKey = "ops:aggregation:hourly:leader"
|
||||
opsAggDailyLeaderLockKey = "ops:aggregation:daily:leader"
|
||||
|
||||
opsAggHourlyLeaderLockTTL = 15 * time.Minute
|
||||
opsAggDailyLeaderLockTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// OpsAggregationService periodically backfills ops_metrics_hourly / ops_metrics_daily
|
||||
// for stable long-window dashboard queries.
|
||||
//
|
||||
// It is safe to run in multi-replica deployments when Redis is available (leader lock).
|
||||
type OpsAggregationService struct {
|
||||
opsRepo OpsRepository
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
|
||||
db *sql.DB
|
||||
redisClient *redis.Client
|
||||
instanceID string
|
||||
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
hourlyMu sync.Mutex
|
||||
dailyMu sync.Mutex
|
||||
|
||||
skipLogMu sync.Mutex
|
||||
skipLogAt time.Time
|
||||
}
|
||||
|
||||
func NewOpsAggregationService(
|
||||
opsRepo OpsRepository,
|
||||
settingRepo SettingRepository,
|
||||
db *sql.DB,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsAggregationService {
|
||||
return &OpsAggregationService{
|
||||
opsRepo: opsRepo,
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
instanceID: uuid.NewString(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.startOnce.Do(func() {
|
||||
if s.stopCh == nil {
|
||||
s.stopCh = make(chan struct{})
|
||||
}
|
||||
go s.hourlyLoop()
|
||||
go s.dailyLoop()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.stopCh != nil {
|
||||
close(s.stopCh)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) hourlyLoop() {
|
||||
// First run immediately.
|
||||
s.aggregateHourly()
|
||||
|
||||
ticker := time.NewTicker(opsAggHourlyInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.aggregateHourly()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) dailyLoop() {
|
||||
// First run immediately.
|
||||
s.aggregateDaily()
|
||||
|
||||
ticker := time.NewTicker(opsAggDailyInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.aggregateDaily()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) aggregateHourly() {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil {
|
||||
if !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Ops.Aggregation.Enabled {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAggHourlyTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !s.isMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx, opsAggHourlyLeaderLockKey, opsAggHourlyLeaderLockTTL, "[OpsAggregation][hourly]")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
s.hourlyMu.Lock()
|
||||
defer s.hourlyMu.Unlock()
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
// Aggregate stable full hours only.
|
||||
end := utcFloorToHour(time.Now().UTC().Add(-opsAggSafeDelay))
|
||||
start := end.Add(-opsAggBackfillWindow)
|
||||
|
||||
// Resume from the latest bucket with overlap.
|
||||
{
|
||||
ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
|
||||
latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax)
|
||||
cancelMax()
|
||||
if err != nil {
|
||||
log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err)
|
||||
} else if ok {
|
||||
candidate := latest.Add(-opsAggHourlyOverlap)
|
||||
if candidate.After(start) {
|
||||
start = candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start = utcFloorToHour(start)
|
||||
if !start.Before(end) {
|
||||
return
|
||||
}
|
||||
|
||||
var aggErr error
|
||||
for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggHourlyChunk) {
|
||||
chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end)
|
||||
if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil {
|
||||
aggErr = err
|
||||
log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
finishedAt := time.Now().UTC()
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
dur := durationMs
|
||||
|
||||
if aggErr != nil {
|
||||
msg := truncateString(aggErr.Error(), 2048)
|
||||
errAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggHourlyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &errAt,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggHourlyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) aggregateDaily() {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil {
|
||||
if !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Ops.Aggregation.Enabled {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAggDailyTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !s.isMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx, opsAggDailyLeaderLockKey, opsAggDailyLeaderLockTTL, "[OpsAggregation][daily]")
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
s.dailyMu.Lock()
|
||||
defer s.dailyMu.Unlock()
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
end := utcFloorToDay(time.Now().UTC())
|
||||
start := end.Add(-opsAggBackfillWindow)
|
||||
|
||||
{
|
||||
ctxMax, cancelMax := context.WithTimeout(context.Background(), opsAggMaxQueryTimeout)
|
||||
latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax)
|
||||
cancelMax()
|
||||
if err != nil {
|
||||
log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err)
|
||||
} else if ok {
|
||||
candidate := latest.Add(-opsAggDailyOverlap)
|
||||
if candidate.After(start) {
|
||||
start = candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start = utcFloorToDay(start)
|
||||
if !start.Before(end) {
|
||||
return
|
||||
}
|
||||
|
||||
var aggErr error
|
||||
for cursor := start; cursor.Before(end); cursor = cursor.Add(opsAggDailyChunk) {
|
||||
chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end)
|
||||
if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil {
|
||||
aggErr = err
|
||||
log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
finishedAt := time.Now().UTC()
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
dur := durationMs
|
||||
|
||||
if aggErr != nil {
|
||||
msg := truncateString(aggErr.Error(), 2048)
|
||||
errAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggDailyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &errAt,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer hbCancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAggDailyJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) isMonitoringEnabled(ctx context.Context) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return false
|
||||
}
|
||||
if s.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var opsAggReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
func (s *OpsAggregationService) tryAcquireLeaderLock(ctx context.Context, key string, ttl time.Duration, logPrefix string) (func(), bool) {
|
||||
if s == nil {
|
||||
return nil, false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Prefer Redis leader lock when available (multi-instance), but avoid stampeding
|
||||
// the DB when Redis is flaky by falling back to a DB advisory lock.
|
||||
if s.redisClient != nil {
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err == nil {
|
||||
if !ok {
|
||||
s.maybeLogSkip(logPrefix)
|
||||
return nil, false
|
||||
}
|
||||
release := func() {
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, _ = opsAggReleaseScript.Run(ctx2, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
// Redis error: fall through to DB advisory lock.
|
||||
}
|
||||
|
||||
release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
|
||||
if !ok {
|
||||
s.maybeLogSkip(logPrefix)
|
||||
return nil, false
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (s *OpsAggregationService) maybeLogSkip(prefix string) {
|
||||
s.skipLogMu.Lock()
|
||||
defer s.skipLogMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < time.Minute {
|
||||
return
|
||||
}
|
||||
s.skipLogAt = now
|
||||
if prefix == "" {
|
||||
prefix = "[OpsAggregation]"
|
||||
}
|
||||
log.Printf("%s leader lock held by another instance; skipping", prefix)
|
||||
}
|
||||
|
||||
func utcFloorToHour(t time.Time) time.Time {
|
||||
return t.UTC().Truncate(time.Hour)
|
||||
}
|
||||
|
||||
func utcFloorToDay(t time.Time) time.Time {
|
||||
u := t.UTC()
|
||||
y, m, d := u.Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func minTime(a, b time.Time) time.Time {
|
||||
if a.Before(b) {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
922
backend/internal/service/ops_alert_evaluator_service.go
Normal file
922
backend/internal/service/ops_alert_evaluator_service.go
Normal file
@@ -0,0 +1,922 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAlertEvaluatorJobName = "ops_alert_evaluator"
|
||||
|
||||
opsAlertEvaluatorTimeout = 45 * time.Second
|
||||
opsAlertEvaluatorLeaderLockKey = "ops:alert:evaluator:leader"
|
||||
opsAlertEvaluatorLeaderLockTTL = 90 * time.Second
|
||||
opsAlertEvaluatorSkipLogInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
var opsAlertEvaluatorReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
type OpsAlertEvaluatorService struct {
|
||||
opsService *OpsService
|
||||
opsRepo OpsRepository
|
||||
emailService *EmailService
|
||||
|
||||
redisClient *redis.Client
|
||||
cfg *config.Config
|
||||
instanceID string
|
||||
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
|
||||
mu sync.Mutex
|
||||
ruleStates map[int64]*opsAlertRuleState
|
||||
|
||||
emailLimiter *slidingWindowLimiter
|
||||
|
||||
skipLogMu sync.Mutex
|
||||
skipLogAt time.Time
|
||||
|
||||
warnNoRedisOnce sync.Once
|
||||
}
|
||||
|
||||
type opsAlertRuleState struct {
|
||||
LastEvaluatedAt time.Time
|
||||
ConsecutiveBreaches int
|
||||
}
|
||||
|
||||
func NewOpsAlertEvaluatorService(
|
||||
opsService *OpsService,
|
||||
opsRepo OpsRepository,
|
||||
emailService *EmailService,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsAlertEvaluatorService {
|
||||
return &OpsAlertEvaluatorService{
|
||||
opsService: opsService,
|
||||
opsRepo: opsRepo,
|
||||
emailService: emailService,
|
||||
redisClient: redisClient,
|
||||
cfg: cfg,
|
||||
instanceID: uuid.NewString(),
|
||||
ruleStates: map[int64]*opsAlertRuleState{},
|
||||
emailLimiter: newSlidingWindowLimiter(0, time.Hour),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.startOnce.Do(func() {
|
||||
if s.stopCh == nil {
|
||||
s.stopCh = make(chan struct{})
|
||||
}
|
||||
go s.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.stopCh != nil {
|
||||
close(s.stopCh)
|
||||
}
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) run() {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
// Start immediately to produce early feedback in ops dashboard.
|
||||
timer := time.NewTimer(0)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
interval := s.getInterval()
|
||||
s.evaluateOnce(interval)
|
||||
timer.Reset(interval)
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) getInterval() time.Duration {
|
||||
// Default.
|
||||
interval := 60 * time.Second
|
||||
|
||||
if s == nil || s.opsService == nil {
|
||||
return interval
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg, err := s.opsService.GetOpsAlertRuntimeSettings(ctx)
|
||||
if err != nil || cfg == nil {
|
||||
return interval
|
||||
}
|
||||
if cfg.EvaluationIntervalSeconds <= 0 {
|
||||
return interval
|
||||
}
|
||||
if cfg.EvaluationIntervalSeconds < 1 {
|
||||
return interval
|
||||
}
|
||||
if cfg.EvaluationIntervalSeconds > int((24 * time.Hour).Seconds()) {
|
||||
return interval
|
||||
}
|
||||
return time.Duration(cfg.EvaluationIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsAlertEvaluatorTimeout)
|
||||
defer cancel()
|
||||
|
||||
if s.opsService != nil && !s.opsService.IsMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
runtimeCfg := defaultOpsAlertRuntimeSettings()
|
||||
if s.opsService != nil {
|
||||
if loaded, err := s.opsService.GetOpsAlertRuntimeSettings(ctx); err == nil && loaded != nil {
|
||||
runtimeCfg = loaded
|
||||
}
|
||||
}
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx, runtimeCfg.DistributedLock)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
rules, err := s.opsRepo.ListAlertRules(ctx)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
log.Printf("[OpsAlertEvaluator] list rules failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
safeEnd := now.Truncate(time.Minute)
|
||||
if safeEnd.IsZero() {
|
||||
safeEnd = now
|
||||
}
|
||||
|
||||
systemMetrics, _ := s.opsRepo.GetLatestSystemMetrics(ctx, 1)
|
||||
|
||||
// Cleanup stale state for removed rules.
|
||||
s.pruneRuleStates(rules)
|
||||
|
||||
for _, rule := range rules {
|
||||
if rule == nil || !rule.Enabled || rule.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
|
||||
|
||||
windowMinutes := rule.WindowMinutes
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
windowStart := safeEnd.Add(-time.Duration(windowMinutes) * time.Minute)
|
||||
windowEnd := safeEnd
|
||||
|
||||
metricValue, ok := s.computeRuleMetric(ctx, rule, systemMetrics, windowStart, windowEnd, scopePlatform, scopeGroupID)
|
||||
if !ok {
|
||||
s.resetRuleState(rule.ID, now)
|
||||
continue
|
||||
}
|
||||
|
||||
breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
|
||||
required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
|
||||
consecutive := s.updateRuleBreaches(rule.ID, now, interval, breachedNow)
|
||||
|
||||
activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if breachedNow && consecutive >= required {
|
||||
if activeEvent != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Scoped silencing: if a matching silence exists, skip creating a firing event.
|
||||
if s.opsService != nil {
|
||||
platform := strings.TrimSpace(scopePlatform)
|
||||
region := scopeRegion
|
||||
if platform != "" {
|
||||
if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
if latestEvent != nil && rule.CooldownMinutes > 0 {
|
||||
cooldown := time.Duration(rule.CooldownMinutes) * time.Minute
|
||||
if now.Sub(latestEvent.FiredAt) < cooldown {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
firedEvent := &OpsAlertEvent{
|
||||
RuleID: rule.ID,
|
||||
Severity: strings.TrimSpace(rule.Severity),
|
||||
Status: OpsAlertStatusFiring,
|
||||
Title: fmt.Sprintf("%s: %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name)),
|
||||
Description: buildOpsAlertDescription(rule, metricValue, windowMinutes, scopePlatform, scopeGroupID),
|
||||
MetricValue: float64Ptr(metricValue),
|
||||
ThresholdValue: float64Ptr(rule.Threshold),
|
||||
Dimensions: buildOpsAlertDimensions(scopePlatform, scopeGroupID),
|
||||
FiredAt: now,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent)
|
||||
if err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if created != nil && created.ID > 0 {
|
||||
s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Not breached: resolve active event if present.
|
||||
if activeEvent != nil {
|
||||
resolvedAt := now
|
||||
if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
|
||||
log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
live := map[int64]struct{}{}
|
||||
for _, r := range rules {
|
||||
if r != nil && r.ID > 0 {
|
||||
live[r.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for id := range s.ruleStates {
|
||||
if _, ok := live[id]; !ok {
|
||||
delete(s.ruleStates, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) resetRuleState(ruleID int64, now time.Time) {
|
||||
if ruleID <= 0 {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
state, ok := s.ruleStates[ruleID]
|
||||
if !ok {
|
||||
state = &opsAlertRuleState{}
|
||||
s.ruleStates[ruleID] = state
|
||||
}
|
||||
state.LastEvaluatedAt = now
|
||||
state.ConsecutiveBreaches = 0
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) updateRuleBreaches(ruleID int64, now time.Time, interval time.Duration, breached bool) int {
|
||||
if ruleID <= 0 {
|
||||
return 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
state, ok := s.ruleStates[ruleID]
|
||||
if !ok {
|
||||
state = &opsAlertRuleState{}
|
||||
s.ruleStates[ruleID] = state
|
||||
}
|
||||
|
||||
if !state.LastEvaluatedAt.IsZero() && interval > 0 {
|
||||
if now.Sub(state.LastEvaluatedAt) > interval*2 {
|
||||
state.ConsecutiveBreaches = 0
|
||||
}
|
||||
}
|
||||
|
||||
state.LastEvaluatedAt = now
|
||||
if breached {
|
||||
state.ConsecutiveBreaches++
|
||||
} else {
|
||||
state.ConsecutiveBreaches = 0
|
||||
}
|
||||
return state.ConsecutiveBreaches
|
||||
}
|
||||
|
||||
func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int {
|
||||
if sustainedMinutes <= 0 {
|
||||
return 1
|
||||
}
|
||||
if interval <= 0 {
|
||||
return sustainedMinutes
|
||||
}
|
||||
required := int(math.Ceil(float64(sustainedMinutes*60) / interval.Seconds()))
|
||||
if required < 1 {
|
||||
return 1
|
||||
}
|
||||
return required
|
||||
}
|
||||
|
||||
func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
|
||||
if filters == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
if v, ok := filters["platform"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
platform = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
if v, ok := filters["group_id"]; ok {
|
||||
switch t := v.(type) {
|
||||
case float64:
|
||||
if t > 0 {
|
||||
id := int64(t)
|
||||
groupID = &id
|
||||
}
|
||||
case int64:
|
||||
if t > 0 {
|
||||
id := t
|
||||
groupID = &id
|
||||
}
|
||||
case int:
|
||||
if t > 0 {
|
||||
id := int64(t)
|
||||
groupID = &id
|
||||
}
|
||||
case string:
|
||||
n, err := strconv.ParseInt(strings.TrimSpace(t), 10, 64)
|
||||
if err == nil && n > 0 {
|
||||
groupID = &n
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := filters["region"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
vv := strings.TrimSpace(s)
|
||||
if vv != "" {
|
||||
region = &vv
|
||||
}
|
||||
}
|
||||
}
|
||||
return platform, groupID, region
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) computeRuleMetric(
|
||||
ctx context.Context,
|
||||
rule *OpsAlertRule,
|
||||
systemMetrics *OpsSystemMetricsSnapshot,
|
||||
start time.Time,
|
||||
end time.Time,
|
||||
platform string,
|
||||
groupID *int64,
|
||||
) (float64, bool) {
|
||||
if rule == nil {
|
||||
return 0, false
|
||||
}
|
||||
switch strings.TrimSpace(rule.MetricType) {
|
||||
case "cpu_usage_percent":
|
||||
if systemMetrics != nil && systemMetrics.CPUUsagePercent != nil {
|
||||
return *systemMetrics.CPUUsagePercent, true
|
||||
}
|
||||
return 0, false
|
||||
case "memory_usage_percent":
|
||||
if systemMetrics != nil && systemMetrics.MemoryUsagePercent != nil {
|
||||
return *systemMetrics.MemoryUsagePercent, true
|
||||
}
|
||||
return 0, false
|
||||
case "concurrency_queue_depth":
|
||||
if systemMetrics != nil && systemMetrics.ConcurrencyQueueDepth != nil {
|
||||
return float64(*systemMetrics.ConcurrencyQueueDepth), true
|
||||
}
|
||||
return 0, false
|
||||
case "group_available_accounts":
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
if availability.Group == nil {
|
||||
return 0, true
|
||||
}
|
||||
return float64(availability.Group.AvailableCount), true
|
||||
case "group_available_ratio":
|
||||
if groupID == nil || *groupID <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
return computeGroupAvailableRatio(availability.Group), true
|
||||
case "account_rate_limited_count":
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.IsRateLimited
|
||||
})), true
|
||||
case "account_error_count":
|
||||
if s == nil || s.opsService == nil {
|
||||
return 0, false
|
||||
}
|
||||
availability, err := s.opsService.GetAccountAvailability(ctx, platform, groupID)
|
||||
if err != nil || availability == nil {
|
||||
return 0, false
|
||||
}
|
||||
return float64(countAccountsByCondition(availability.Accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||
})), true
|
||||
}
|
||||
|
||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, &OpsDashboardFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: platform,
|
||||
GroupID: groupID,
|
||||
QueryMode: OpsQueryModeRaw,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
if overview == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
switch strings.TrimSpace(rule.MetricType) {
|
||||
case "success_rate":
|
||||
if overview.RequestCountSLA <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return overview.SLA * 100, true
|
||||
case "error_rate":
|
||||
if overview.RequestCountSLA <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return overview.ErrorRate * 100, true
|
||||
case "upstream_error_rate":
|
||||
if overview.RequestCountSLA <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return overview.UpstreamErrorRate * 100, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func compareMetric(value float64, operator string, threshold float64) bool {
|
||||
switch strings.TrimSpace(operator) {
|
||||
case ">":
|
||||
return value > threshold
|
||||
case ">=":
|
||||
return value >= threshold
|
||||
case "<":
|
||||
return value < threshold
|
||||
case "<=":
|
||||
return value <= threshold
|
||||
case "==":
|
||||
return value == threshold
|
||||
case "!=":
|
||||
return value != threshold
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func buildOpsAlertDimensions(platform string, groupID *int64) map[string]any {
|
||||
dims := map[string]any{}
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
dims["platform"] = strings.TrimSpace(platform)
|
||||
}
|
||||
if groupID != nil && *groupID > 0 {
|
||||
dims["group_id"] = *groupID
|
||||
}
|
||||
if len(dims) == 0 {
|
||||
return nil
|
||||
}
|
||||
return dims
|
||||
}
|
||||
|
||||
func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes int, platform string, groupID *int64) string {
|
||||
if rule == nil {
|
||||
return ""
|
||||
}
|
||||
scope := "overall"
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
scope = fmt.Sprintf("platform=%s", strings.TrimSpace(platform))
|
||||
}
|
||||
if groupID != nil && *groupID > 0 {
|
||||
scope = fmt.Sprintf("%s group_id=%d", scope, *groupID)
|
||||
}
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
return fmt.Sprintf("%s %s %.2f (current %.2f) over last %dm (%s)",
|
||||
strings.TrimSpace(rule.MetricType),
|
||||
strings.TrimSpace(rule.Operator),
|
||||
rule.Threshold,
|
||||
value,
|
||||
windowMinutes,
|
||||
strings.TrimSpace(scope),
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) {
|
||||
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
|
||||
return
|
||||
}
|
||||
if event.EmailSent {
|
||||
return
|
||||
}
|
||||
if !rule.NotifyEmail {
|
||||
return
|
||||
}
|
||||
|
||||
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
|
||||
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
if len(emailCfg.Alert.Recipients) == 0 {
|
||||
return
|
||||
}
|
||||
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
|
||||
return
|
||||
}
|
||||
|
||||
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
|
||||
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Apply/update rate limiter.
|
||||
s.emailLimiter.SetLimit(emailCfg.Alert.RateLimitPerHour)
|
||||
|
||||
subject := fmt.Sprintf("[Ops Alert][%s] %s", strings.TrimSpace(rule.Severity), strings.TrimSpace(rule.Name))
|
||||
body := buildOpsAlertEmailBody(rule, event)
|
||||
|
||||
anySent := false
|
||||
for _, to := range emailCfg.Alert.Recipients {
|
||||
addr := strings.TrimSpace(to)
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
if !s.emailLimiter.Allow(time.Now().UTC()) {
|
||||
continue
|
||||
}
|
||||
if err := s.emailService.SendEmail(ctx, addr, subject, body); err != nil {
|
||||
// Ignore per-recipient failures; continue best-effort.
|
||||
continue
|
||||
}
|
||||
anySent = true
|
||||
}
|
||||
|
||||
if anySent {
|
||||
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
|
||||
}
|
||||
}
|
||||
|
||||
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
|
||||
if rule == nil || event == nil {
|
||||
return ""
|
||||
}
|
||||
metric := strings.TrimSpace(rule.MetricType)
|
||||
value := "-"
|
||||
threshold := fmt.Sprintf("%.2f", rule.Threshold)
|
||||
if event.MetricValue != nil {
|
||||
value = fmt.Sprintf("%.2f", *event.MetricValue)
|
||||
}
|
||||
if event.ThresholdValue != nil {
|
||||
threshold = fmt.Sprintf("%.2f", *event.ThresholdValue)
|
||||
}
|
||||
return fmt.Sprintf(`
|
||||
<h2>Ops Alert</h2>
|
||||
<p><b>Rule</b>: %s</p>
|
||||
<p><b>Severity</b>: %s</p>
|
||||
<p><b>Status</b>: %s</p>
|
||||
<p><b>Metric</b>: %s %s %s</p>
|
||||
<p><b>Fired at</b>: %s</p>
|
||||
<p><b>Description</b>: %s</p>
|
||||
`,
|
||||
htmlEscape(rule.Name),
|
||||
htmlEscape(rule.Severity),
|
||||
htmlEscape(event.Status),
|
||||
htmlEscape(metric),
|
||||
htmlEscape(rule.Operator),
|
||||
htmlEscape(fmt.Sprintf("%s (threshold %s)", value, threshold)),
|
||||
event.FiredAt.Format(time.RFC3339),
|
||||
htmlEscape(event.Description),
|
||||
)
|
||||
}
|
||||
|
||||
func shouldSendOpsAlertEmailByMinSeverity(minSeverity string, ruleSeverity string) bool {
|
||||
minSeverity = strings.ToLower(strings.TrimSpace(minSeverity))
|
||||
if minSeverity == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
eventLevel := opsEmailSeverityForOps(ruleSeverity)
|
||||
minLevel := strings.ToLower(minSeverity)
|
||||
|
||||
rank := func(level string) int {
|
||||
switch level {
|
||||
case "critical":
|
||||
return 3
|
||||
case "warning":
|
||||
return 2
|
||||
case "info":
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return rank(eventLevel) >= rank(minLevel)
|
||||
}
|
||||
|
||||
func opsEmailSeverityForOps(severity string) string {
|
||||
switch strings.ToUpper(strings.TrimSpace(severity)) {
|
||||
case "P0":
|
||||
return "critical"
|
||||
case "P1":
|
||||
return "warning"
|
||||
default:
|
||||
return "info"
|
||||
}
|
||||
}
|
||||
|
||||
func isOpsAlertSilenced(now time.Time, rule *OpsAlertRule, event *OpsAlertEvent, silencing OpsAlertSilencingSettings) bool {
|
||||
if !silencing.Enabled {
|
||||
return false
|
||||
}
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
if strings.TrimSpace(silencing.GlobalUntilRFC3339) != "" {
|
||||
if t, err := time.Parse(time.RFC3339, strings.TrimSpace(silencing.GlobalUntilRFC3339)); err == nil {
|
||||
if now.Before(t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, entry := range silencing.Entries {
|
||||
untilRaw := strings.TrimSpace(entry.UntilRFC3339)
|
||||
if untilRaw == "" {
|
||||
continue
|
||||
}
|
||||
until, err := time.Parse(time.RFC3339, untilRaw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if now.After(until) {
|
||||
continue
|
||||
}
|
||||
if entry.RuleID != nil && rule != nil && rule.ID > 0 && *entry.RuleID != rule.ID {
|
||||
continue
|
||||
}
|
||||
if len(entry.Severities) > 0 {
|
||||
match := false
|
||||
for _, s := range entry.Severities {
|
||||
if strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(event.Severity)) || strings.EqualFold(strings.TrimSpace(s), strings.TrimSpace(rule.Severity)) {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, lock OpsDistributedLockSettings) (func(), bool) {
|
||||
if !lock.Enabled {
|
||||
return nil, true
|
||||
}
|
||||
if s.redisClient == nil {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock")
|
||||
})
|
||||
return nil, true
|
||||
}
|
||||
key := strings.TrimSpace(lock.Key)
|
||||
if key == "" {
|
||||
key = opsAlertEvaluatorLeaderLockKey
|
||||
}
|
||||
ttl := time.Duration(lock.TTLSeconds) * time.Second
|
||||
if ttl <= 0 {
|
||||
ttl = opsAlertEvaluatorLeaderLockTTL
|
||||
}
|
||||
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err != nil {
|
||||
// Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky.
|
||||
// Single-node deployments can disable the distributed lock via runtime settings.
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err)
|
||||
})
|
||||
return nil, false
|
||||
}
|
||||
if !ok {
|
||||
s.maybeLogSkip(key)
|
||||
return nil, false
|
||||
}
|
||||
return func() {
|
||||
_, _ = opsAlertEvaluatorReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
|
||||
s.skipLogMu.Lock()
|
||||
defer s.skipLogMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if !s.skipLogAt.IsZero() && now.Sub(s.skipLogAt) < opsAlertEvaluatorSkipLogInterval {
|
||||
return
|
||||
}
|
||||
s.skipLogAt = now
|
||||
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAlertEvaluatorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsAlertEvaluatorService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
|
||||
if s == nil || s.opsRepo == nil || err == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsAlertEvaluatorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &now,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
|
||||
func htmlEscape(s string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
"&", "&",
|
||||
"<", "<",
|
||||
">", ">",
|
||||
`"`, """,
|
||||
"'", "'",
|
||||
)
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
|
||||
type slidingWindowLimiter struct {
|
||||
mu sync.Mutex
|
||||
limit int
|
||||
window time.Duration
|
||||
sent []time.Time
|
||||
}
|
||||
|
||||
func newSlidingWindowLimiter(limit int, window time.Duration) *slidingWindowLimiter {
|
||||
if window <= 0 {
|
||||
window = time.Hour
|
||||
}
|
||||
return &slidingWindowLimiter{
|
||||
limit: limit,
|
||||
window: window,
|
||||
sent: []time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *slidingWindowLimiter) SetLimit(limit int) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.limit = limit
|
||||
}
|
||||
|
||||
func (l *slidingWindowLimiter) Allow(now time.Time) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if l.limit <= 0 {
|
||||
return true
|
||||
}
|
||||
cutoff := now.Add(-l.window)
|
||||
keep := l.sent[:0]
|
||||
for _, t := range l.sent {
|
||||
if t.After(cutoff) {
|
||||
keep = append(keep, t)
|
||||
}
|
||||
}
|
||||
l.sent = keep
|
||||
if len(l.sent) >= l.limit {
|
||||
return false
|
||||
}
|
||||
l.sent = append(l.sent, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// computeGroupAvailableRatio returns the available percentage for a group.
|
||||
// Formula: (AvailableCount / TotalAccounts) * 100.
|
||||
// Returns 0 when TotalAccounts is 0.
|
||||
func computeGroupAvailableRatio(group *GroupAvailability) float64 {
|
||||
if group == nil || group.TotalAccounts <= 0 {
|
||||
return 0
|
||||
}
|
||||
return (float64(group.AvailableCount) / float64(group.TotalAccounts)) * 100
|
||||
}
|
||||
|
||||
// countAccountsByCondition counts accounts that satisfy the given condition.
|
||||
func countAccountsByCondition(accounts map[int64]*AccountAvailability, condition func(*AccountAvailability) bool) int64 {
|
||||
if len(accounts) == 0 || condition == nil {
|
||||
return 0
|
||||
}
|
||||
var count int64
|
||||
for _, account := range accounts {
|
||||
if account != nil && condition(account) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
210
backend/internal/service/ops_alert_evaluator_service_test.go
Normal file
210
backend/internal/service/ops_alert_evaluator_service_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubOpsRepo struct {
|
||||
OpsRepository
|
||||
overview *OpsDashboardOverview
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubOpsRepo) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.overview != nil {
|
||||
return s.overview, nil
|
||||
}
|
||||
return &OpsDashboardOverview{}, nil
|
||||
}
|
||||
|
||||
func TestComputeGroupAvailableRatio(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("正常情况: 10个账号, 8个可用 = 80%", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := computeGroupAvailableRatio(&GroupAvailability{
|
||||
TotalAccounts: 10,
|
||||
AvailableCount: 8,
|
||||
})
|
||||
require.InDelta(t, 80.0, got, 0.0001)
|
||||
})
|
||||
|
||||
t.Run("边界情况: TotalAccounts = 0 应返回 0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := computeGroupAvailableRatio(&GroupAvailability{
|
||||
TotalAccounts: 0,
|
||||
AvailableCount: 8,
|
||||
})
|
||||
require.Equal(t, 0.0, got)
|
||||
})
|
||||
|
||||
t.Run("边界情况: AvailableCount = 0 应返回 0%", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := computeGroupAvailableRatio(&GroupAvailability{
|
||||
TotalAccounts: 10,
|
||||
AvailableCount: 0,
|
||||
})
|
||||
require.Equal(t, 0.0, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCountAccountsByCondition(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("测试限流账号统计: acc.IsRateLimited", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
accounts := map[int64]*AccountAvailability{
|
||||
1: {IsRateLimited: true},
|
||||
2: {IsRateLimited: false},
|
||||
3: {IsRateLimited: true},
|
||||
}
|
||||
|
||||
got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.IsRateLimited
|
||||
})
|
||||
require.Equal(t, int64(2), got)
|
||||
})
|
||||
|
||||
t.Run("测试错误账号统计(排除临时不可调度): acc.HasError && acc.TempUnschedulableUntil == nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
until := time.Now().UTC().Add(5 * time.Minute)
|
||||
accounts := map[int64]*AccountAvailability{
|
||||
1: {HasError: true},
|
||||
2: {HasError: true, TempUnschedulableUntil: &until},
|
||||
3: {HasError: false},
|
||||
}
|
||||
|
||||
got := countAccountsByCondition(accounts, func(acc *AccountAvailability) bool {
|
||||
return acc.HasError && acc.TempUnschedulableUntil == nil
|
||||
})
|
||||
require.Equal(t, int64(1), got)
|
||||
})
|
||||
|
||||
t.Run("边界情况: 空 map 应返回 0", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := countAccountsByCondition(map[int64]*AccountAvailability{}, func(acc *AccountAvailability) bool {
|
||||
return acc.IsRateLimited
|
||||
})
|
||||
require.Equal(t, int64(0), got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestComputeRuleMetricNewIndicators(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
groupID := int64(101)
|
||||
platform := "openai"
|
||||
|
||||
availability := &OpsAccountAvailability{
|
||||
Group: &GroupAvailability{
|
||||
GroupID: groupID,
|
||||
TotalAccounts: 10,
|
||||
AvailableCount: 8,
|
||||
},
|
||||
Accounts: map[int64]*AccountAvailability{
|
||||
1: {IsRateLimited: true},
|
||||
2: {IsRateLimited: true},
|
||||
3: {HasError: true},
|
||||
4: {HasError: true, TempUnschedulableUntil: timePtr(time.Now().UTC().Add(2 * time.Minute))},
|
||||
5: {HasError: false, IsRateLimited: false},
|
||||
},
|
||||
}
|
||||
|
||||
opsService := &OpsService{
|
||||
getAccountAvailability: func(_ context.Context, _ string, _ *int64) (*OpsAccountAvailability, error) {
|
||||
return availability, nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpsAlertEvaluatorService{
|
||||
opsService: opsService,
|
||||
opsRepo: &stubOpsRepo{overview: &OpsDashboardOverview{}},
|
||||
}
|
||||
|
||||
start := time.Now().UTC().Add(-5 * time.Minute)
|
||||
end := time.Now().UTC()
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
metricType string
|
||||
groupID *int64
|
||||
wantValue float64
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "group_available_accounts",
|
||||
metricType: "group_available_accounts",
|
||||
groupID: &groupID,
|
||||
wantValue: 8,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "group_available_ratio",
|
||||
metricType: "group_available_ratio",
|
||||
groupID: &groupID,
|
||||
wantValue: 80.0,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "account_rate_limited_count",
|
||||
metricType: "account_rate_limited_count",
|
||||
groupID: nil,
|
||||
wantValue: 2,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "account_error_count",
|
||||
metricType: "account_error_count",
|
||||
groupID: nil,
|
||||
wantValue: 1,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "group_available_accounts without group_id returns false",
|
||||
metricType: "group_available_accounts",
|
||||
groupID: nil,
|
||||
wantValue: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "group_available_ratio without group_id returns false",
|
||||
metricType: "group_available_ratio",
|
||||
groupID: nil,
|
||||
wantValue: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rule := &OpsAlertRule{
|
||||
MetricType: tt.metricType,
|
||||
}
|
||||
gotValue, gotOK := svc.computeRuleMetric(ctx, rule, nil, start, end, platform, tt.groupID)
|
||||
require.Equal(t, tt.wantOK, gotOK)
|
||||
if !tt.wantOK {
|
||||
return
|
||||
}
|
||||
require.InDelta(t, tt.wantValue, gotValue, 0.0001)
|
||||
})
|
||||
}
|
||||
}
|
||||
95
backend/internal/service/ops_alert_models.go
Normal file
95
backend/internal/service/ops_alert_models.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// Ops alert rule/event models.
|
||||
//
|
||||
// NOTE: These are admin-facing DTOs and intentionally keep JSON naming aligned
|
||||
// with the existing ops dashboard frontend (backup style).
|
||||
|
||||
const (
|
||||
OpsAlertStatusFiring = "firing"
|
||||
OpsAlertStatusResolved = "resolved"
|
||||
OpsAlertStatusManualResolved = "manual_resolved"
|
||||
)
|
||||
|
||||
type OpsAlertRule struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
|
||||
Enabled bool `json:"enabled"`
|
||||
Severity string `json:"severity"`
|
||||
|
||||
MetricType string `json:"metric_type"`
|
||||
Operator string `json:"operator"`
|
||||
Threshold float64 `json:"threshold"`
|
||||
|
||||
WindowMinutes int `json:"window_minutes"`
|
||||
SustainedMinutes int `json:"sustained_minutes"`
|
||||
CooldownMinutes int `json:"cooldown_minutes"`
|
||||
|
||||
NotifyEmail bool `json:"notify_email"`
|
||||
|
||||
Filters map[string]any `json:"filters,omitempty"`
|
||||
|
||||
LastTriggeredAt *time.Time `json:"last_triggered_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type OpsAlertEvent struct {
|
||||
ID int64 `json:"id"`
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
|
||||
MetricValue *float64 `json:"metric_value,omitempty"`
|
||||
ThresholdValue *float64 `json:"threshold_value,omitempty"`
|
||||
|
||||
Dimensions map[string]any `json:"dimensions,omitempty"`
|
||||
|
||||
FiredAt time.Time `json:"fired_at"`
|
||||
ResolvedAt *time.Time `json:"resolved_at,omitempty"`
|
||||
|
||||
EmailSent bool `json:"email_sent"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type OpsAlertSilence struct {
|
||||
ID int64 `json:"id"`
|
||||
|
||||
RuleID int64 `json:"rule_id"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Region *string `json:"region,omitempty"`
|
||||
|
||||
Until time.Time `json:"until"`
|
||||
Reason string `json:"reason"`
|
||||
|
||||
CreatedBy *int64 `json:"created_by,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type OpsAlertEventFilter struct {
|
||||
Limit int
|
||||
|
||||
// Cursor pagination (descending by fired_at, then id).
|
||||
BeforeFiredAt *time.Time
|
||||
BeforeID *int64
|
||||
|
||||
// Optional filters.
|
||||
Status string
|
||||
Severity string
|
||||
EmailSent *bool
|
||||
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
// Dimensions filters (best-effort).
|
||||
Platform string
|
||||
GroupID *int64
|
||||
}
|
||||
232
backend/internal/service/ops_alerts.go
Normal file
232
backend/internal/service/ops_alerts.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return []*OpsAlertRule{}, nil
|
||||
}
|
||||
return s.opsRepo.ListAlertRules(ctx)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertRule(ctx context.Context, rule *OpsAlertRule) (*OpsAlertRule, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if rule == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE", "invalid rule")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertRule(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertRule(ctx context.Context, rule *OpsAlertRule) (*OpsAlertRule, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if rule == nil || rule.ID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE", "invalid rule")
|
||||
}
|
||||
|
||||
updated, err := s.opsRepo.UpdateAlertRule(ctx, rule)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_RULE_NOT_FOUND", "alert rule not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if id <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if err := s.opsRepo.DeleteAlertRule(ctx, id); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return infraerrors.NotFound("OPS_ALERT_RULE_NOT_FOUND", "alert rule not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return []*OpsAlertEvent{}, nil
|
||||
}
|
||||
return s.opsRepo.ListAlertEvents(ctx, filter)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if ev == nil {
|
||||
return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
|
||||
}
|
||||
if input.RuleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if strings.TrimSpace(input.Platform) == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
|
||||
}
|
||||
if input.Until.IsZero() {
|
||||
return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertSilence(ctx, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
if strings.TrimSpace(platform) == "" {
|
||||
return false, nil
|
||||
}
|
||||
return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
|
||||
}
|
||||
return s.opsRepo.GetLatestAlertEvent(ctx, ruleID)
|
||||
}
|
||||
|
||||
func (s *OpsService) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if event == nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_EVENT", "invalid event")
|
||||
}
|
||||
|
||||
created, err := s.opsRepo.CreateAlertEvent(ctx, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
status = strings.TrimSpace(status)
|
||||
if status == "" {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||
}
|
||||
if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
|
||||
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
|
||||
}
|
||||
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
|
||||
}
|
||||
|
||||
func (s *OpsService) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
|
||||
}
|
||||
return s.opsRepo.UpdateAlertEventEmailSent(ctx, eventID, emailSent)
|
||||
}
|
||||
365
backend/internal/service/ops_cleanup_service.go
Normal file
365
backend/internal/service/ops_cleanup_service.go
Normal file
@@ -0,0 +1,365 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
opsCleanupJobName = "ops_cleanup"
|
||||
|
||||
opsCleanupLeaderLockKeyDefault = "ops:cleanup:leader"
|
||||
opsCleanupLeaderLockTTLDefault = 30 * time.Minute
|
||||
)
|
||||
|
||||
var opsCleanupCronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
|
||||
var opsCleanupReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
// OpsCleanupService periodically deletes old ops data to prevent unbounded DB growth.
|
||||
//
|
||||
// - Scheduling: 5-field cron spec (minute hour dom month dow).
|
||||
// - Multi-instance: best-effort Redis leader lock so only one node runs cleanup.
|
||||
// - Safety: deletes in batches to avoid long transactions.
|
||||
type OpsCleanupService struct {
|
||||
opsRepo OpsRepository
|
||||
db *sql.DB
|
||||
redisClient *redis.Client
|
||||
cfg *config.Config
|
||||
|
||||
instanceID string
|
||||
|
||||
cron *cron.Cron
|
||||
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
warnNoRedisOnce sync.Once
|
||||
}
|
||||
|
||||
func NewOpsCleanupService(
|
||||
opsRepo OpsRepository,
|
||||
db *sql.DB,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsCleanupService {
|
||||
return &OpsCleanupService{
|
||||
opsRepo: opsRepo,
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
cfg: cfg,
|
||||
instanceID: uuid.NewString(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled {
|
||||
log.Printf("[OpsCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.opsRepo == nil || s.db == nil {
|
||||
log.Printf("[OpsCleanup] not started (missing deps)")
|
||||
return
|
||||
}
|
||||
|
||||
s.startOnce.Do(func() {
|
||||
schedule := "0 2 * * *"
|
||||
if s.cfg != nil && strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule) != "" {
|
||||
schedule = strings.TrimSpace(s.cfg.Ops.Cleanup.Schedule)
|
||||
}
|
||||
|
||||
loc := time.Local
|
||||
if s.cfg != nil && strings.TrimSpace(s.cfg.Timezone) != "" {
|
||||
if parsed, err := time.LoadLocation(strings.TrimSpace(s.cfg.Timezone)); err == nil && parsed != nil {
|
||||
loc = parsed
|
||||
}
|
||||
}
|
||||
|
||||
c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc))
|
||||
_, err := c.AddFunc(schedule, func() { s.runScheduled() })
|
||||
if err != nil {
|
||||
log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err)
|
||||
return
|
||||
}
|
||||
s.cron = c
|
||||
s.cron.Start()
|
||||
log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.cron != nil {
|
||||
ctx := s.cron.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Printf("[OpsCleanup] cron stop timed out")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) runScheduled() {
|
||||
if s == nil || s.db == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
release, ok := s.tryAcquireLeaderLock(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
runAt := startedAt
|
||||
|
||||
counts, err := s.runCleanupOnce(ctx)
|
||||
if err != nil {
|
||||
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
|
||||
log.Printf("[OpsCleanup] cleanup failed: %v", err)
|
||||
return
|
||||
}
|
||||
s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
|
||||
log.Printf("[OpsCleanup] cleanup complete: %s", counts)
|
||||
}
|
||||
|
||||
type opsCleanupDeletedCounts struct {
|
||||
errorLogs int64
|
||||
retryAttempts int64
|
||||
alertEvents int64
|
||||
systemMetrics int64
|
||||
hourlyPreagg int64
|
||||
dailyPreagg int64
|
||||
}
|
||||
|
||||
func (c opsCleanupDeletedCounts) String() string {
|
||||
return fmt.Sprintf(
|
||||
"error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
|
||||
c.errorLogs,
|
||||
c.retryAttempts,
|
||||
c.alertEvents,
|
||||
c.systemMetrics,
|
||||
c.hourlyPreagg,
|
||||
c.dailyPreagg,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDeletedCounts, error) {
|
||||
out := opsCleanupDeletedCounts{}
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
batchSize := 5000
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Error-like tables: error logs / retry attempts / alert events.
|
||||
if days := s.cfg.Ops.Cleanup.ErrorLogRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_error_logs", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.errorLogs = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_retry_attempts", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.retryAttempts = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_alert_events", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.alertEvents = n
|
||||
}
|
||||
|
||||
// Minute-level metrics snapshots.
|
||||
if days := s.cfg.Ops.Cleanup.MinuteMetricsRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_system_metrics", "created_at", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.systemMetrics = n
|
||||
}
|
||||
|
||||
// Pre-aggregation tables (hourly/daily).
|
||||
if days := s.cfg.Ops.Cleanup.HourlyMetricsRetentionDays; days > 0 {
|
||||
cutoff := now.AddDate(0, 0, -days)
|
||||
n, err := deleteOldRowsByID(ctx, s.db, "ops_metrics_hourly", "bucket_start", cutoff, batchSize, false)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.hourlyPreagg = n
|
||||
|
||||
n, err = deleteOldRowsByID(ctx, s.db, "ops_metrics_daily", "bucket_date", cutoff, batchSize, true)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
out.dailyPreagg = n
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func deleteOldRowsByID(
|
||||
ctx context.Context,
|
||||
db *sql.DB,
|
||||
table string,
|
||||
timeColumn string,
|
||||
cutoff time.Time,
|
||||
batchSize int,
|
||||
castCutoffToDate bool,
|
||||
) (int64, error) {
|
||||
if db == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if batchSize <= 0 {
|
||||
batchSize = 5000
|
||||
}
|
||||
|
||||
where := fmt.Sprintf("%s < $1", timeColumn)
|
||||
if castCutoffToDate {
|
||||
where = fmt.Sprintf("%s < $1::date", timeColumn)
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
WITH batch AS (
|
||||
SELECT id FROM %s
|
||||
WHERE %s
|
||||
ORDER BY id
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM %s
|
||||
WHERE id IN (SELECT id FROM batch)
|
||||
`, table, where, table)
|
||||
|
||||
var total int64
|
||||
for {
|
||||
res, err := db.ExecContext(ctx, q, cutoff, batchSize)
|
||||
if err != nil {
|
||||
// If ops tables aren't present yet (partial deployments), treat as no-op.
|
||||
if strings.Contains(strings.ToLower(err.Error()), "does not exist") && strings.Contains(strings.ToLower(err.Error()), "relation") {
|
||||
return total, nil
|
||||
}
|
||||
return total, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += affected
|
||||
if affected == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
|
||||
if s == nil {
|
||||
return nil, false
|
||||
}
|
||||
// In simple run mode, assume single instance.
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil, true
|
||||
}
|
||||
|
||||
key := opsCleanupLeaderLockKeyDefault
|
||||
ttl := opsCleanupLeaderLockTTLDefault
|
||||
|
||||
// Prefer Redis leader lock when available, but avoid stampeding the DB when Redis is flaky by
|
||||
// falling back to a DB advisory lock.
|
||||
if s.redisClient != nil {
|
||||
ok, err := s.redisClient.SetNX(ctx, key, s.instanceID, ttl).Result()
|
||||
if err == nil {
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return func() {
|
||||
_, _ = opsCleanupReleaseScript.Run(ctx, s.redisClient, []string{key}, s.instanceID).Result()
|
||||
}, true
|
||||
}
|
||||
// Redis error: fall back to DB advisory lock.
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err)
|
||||
})
|
||||
} else {
|
||||
s.warnNoRedisOnce.Do(func() {
|
||||
log.Printf("[OpsCleanup] redis not configured; using DB advisory lock")
|
||||
})
|
||||
}
|
||||
|
||||
release, ok := tryAcquireDBAdvisoryLock(ctx, s.db, hashAdvisoryLockID(key))
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
|
||||
if s == nil || s.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsCleanupJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &now,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OpsCleanupService) recordHeartbeatError(runAt time.Time, duration time.Duration, err error) {
|
||||
if s == nil || s.opsRepo == nil || err == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
durMs := duration.Milliseconds()
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsCleanupJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &now,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &durMs,
|
||||
})
|
||||
}
|
||||
257
backend/internal/service/ops_concurrency.go
Normal file
257
backend/internal/service/ops_concurrency.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
opsAccountsPageSize = 100
|
||||
opsConcurrencyBatchChunkSize = 200
|
||||
)
|
||||
|
||||
func (s *OpsService) listAllAccountsForOps(ctx context.Context, platformFilter string) ([]Account, error) {
|
||||
if s == nil || s.accountRepo == nil {
|
||||
return []Account{}, nil
|
||||
}
|
||||
|
||||
out := make([]Account, 0, 128)
|
||||
page := 1
|
||||
for {
|
||||
accounts, pageInfo, err := s.accountRepo.ListWithFilters(ctx, pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: opsAccountsPageSize,
|
||||
}, platformFilter, "", "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
out = append(out, accounts...)
|
||||
if pageInfo != nil && int64(len(out)) >= pageInfo.Total {
|
||||
break
|
||||
}
|
||||
if len(accounts) < opsAccountsPageSize {
|
||||
break
|
||||
}
|
||||
|
||||
page++
|
||||
if page > 10_000 {
|
||||
log.Printf("[Ops] listAllAccountsForOps: aborting after too many pages (platform=%q)", platformFilter)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) getAccountsLoadMapBestEffort(ctx context.Context, accounts []Account) map[int64]*AccountLoadInfo {
|
||||
if s == nil || s.concurrencyService == nil {
|
||||
return map[int64]*AccountLoadInfo{}
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*AccountLoadInfo{}
|
||||
}
|
||||
|
||||
// De-duplicate IDs (and keep the max concurrency to avoid under-reporting).
|
||||
unique := make(map[int64]int, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if prev, ok := unique[acc.ID]; !ok || acc.Concurrency > prev {
|
||||
unique[acc.ID] = acc.Concurrency
|
||||
}
|
||||
}
|
||||
|
||||
batch := make([]AccountWithConcurrency, 0, len(unique))
|
||||
for id, maxConc := range unique {
|
||||
batch = append(batch, AccountWithConcurrency{
|
||||
ID: id,
|
||||
MaxConcurrency: maxConc,
|
||||
})
|
||||
}
|
||||
|
||||
out := make(map[int64]*AccountLoadInfo, len(batch))
|
||||
for i := 0; i < len(batch); i += opsConcurrencyBatchChunkSize {
|
||||
end := i + opsConcurrencyBatchChunkSize
|
||||
if end > len(batch) {
|
||||
end = len(batch)
|
||||
}
|
||||
part, err := s.concurrencyService.GetAccountsLoadBatch(ctx, batch[i:end])
|
||||
if err != nil {
|
||||
// Best-effort: return zeros rather than failing the ops UI.
|
||||
log.Printf("[Ops] GetAccountsLoadBatch failed: %v", err)
|
||||
continue
|
||||
}
|
||||
for k, v := range part {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// GetConcurrencyStats returns real-time concurrency usage aggregated by platform/group/account.
|
||||
//
|
||||
// Optional filters:
|
||||
// - platformFilter: only include accounts in that platform (best-effort reduces DB load)
|
||||
// - groupIDFilter: only include accounts that belong to that group
|
||||
func (s *OpsService) GetConcurrencyStats(
|
||||
ctx context.Context,
|
||||
platformFilter string,
|
||||
groupIDFilter *int64,
|
||||
) (map[string]*PlatformConcurrencyInfo, map[int64]*GroupConcurrencyInfo, map[int64]*AccountConcurrencyInfo, *time.Time, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
accounts, err := s.listAllAccountsForOps(ctx, platformFilter)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
collectedAt := time.Now()
|
||||
loadMap := s.getAccountsLoadMapBestEffort(ctx, accounts)
|
||||
|
||||
platform := make(map[string]*PlatformConcurrencyInfo)
|
||||
group := make(map[int64]*GroupConcurrencyInfo)
|
||||
account := make(map[int64]*AccountConcurrencyInfo)
|
||||
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var matchedGroup *Group
|
||||
if groupIDFilter != nil && *groupIDFilter > 0 {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if grp.ID == *groupIDFilter {
|
||||
matchedGroup = grp
|
||||
break
|
||||
}
|
||||
}
|
||||
// Group filter provided: skip accounts not in that group.
|
||||
if matchedGroup == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
load := loadMap[acc.ID]
|
||||
currentInUse := int64(0)
|
||||
waiting := int64(0)
|
||||
if load != nil {
|
||||
currentInUse = int64(load.CurrentConcurrency)
|
||||
waiting = int64(load.WaitingCount)
|
||||
}
|
||||
|
||||
// Account-level view picks one display group (the first group).
|
||||
displayGroupID := int64(0)
|
||||
displayGroupName := ""
|
||||
if matchedGroup != nil {
|
||||
displayGroupID = matchedGroup.ID
|
||||
displayGroupName = matchedGroup.Name
|
||||
} else if len(acc.Groups) > 0 && acc.Groups[0] != nil {
|
||||
displayGroupID = acc.Groups[0].ID
|
||||
displayGroupName = acc.Groups[0].Name
|
||||
}
|
||||
|
||||
if _, ok := account[acc.ID]; !ok {
|
||||
info := &AccountConcurrencyInfo{
|
||||
AccountID: acc.ID,
|
||||
AccountName: acc.Name,
|
||||
Platform: acc.Platform,
|
||||
GroupID: displayGroupID,
|
||||
GroupName: displayGroupName,
|
||||
CurrentInUse: currentInUse,
|
||||
MaxCapacity: int64(acc.Concurrency),
|
||||
WaitingInQueue: waiting,
|
||||
}
|
||||
if info.MaxCapacity > 0 {
|
||||
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
|
||||
}
|
||||
account[acc.ID] = info
|
||||
}
|
||||
|
||||
// Platform aggregation.
|
||||
if acc.Platform != "" {
|
||||
if _, ok := platform[acc.Platform]; !ok {
|
||||
platform[acc.Platform] = &PlatformConcurrencyInfo{
|
||||
Platform: acc.Platform,
|
||||
}
|
||||
}
|
||||
p := platform[acc.Platform]
|
||||
p.MaxCapacity += int64(acc.Concurrency)
|
||||
p.CurrentInUse += currentInUse
|
||||
p.WaitingInQueue += waiting
|
||||
}
|
||||
|
||||
// Group aggregation (one account may contribute to multiple groups).
|
||||
if matchedGroup != nil {
|
||||
grp := matchedGroup
|
||||
if _, ok := group[grp.ID]; !ok {
|
||||
group[grp.ID] = &GroupConcurrencyInfo{
|
||||
GroupID: grp.ID,
|
||||
GroupName: grp.Name,
|
||||
Platform: grp.Platform,
|
||||
}
|
||||
}
|
||||
g := group[grp.ID]
|
||||
if g.GroupName == "" && grp.Name != "" {
|
||||
g.GroupName = grp.Name
|
||||
}
|
||||
if g.Platform != "" && grp.Platform != "" && g.Platform != grp.Platform {
|
||||
// Groups are expected to be platform-scoped. If mismatch is observed, avoid misleading labels.
|
||||
g.Platform = ""
|
||||
}
|
||||
g.MaxCapacity += int64(acc.Concurrency)
|
||||
g.CurrentInUse += currentInUse
|
||||
g.WaitingInQueue += waiting
|
||||
} else {
|
||||
for _, grp := range acc.Groups {
|
||||
if grp == nil || grp.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := group[grp.ID]; !ok {
|
||||
group[grp.ID] = &GroupConcurrencyInfo{
|
||||
GroupID: grp.ID,
|
||||
GroupName: grp.Name,
|
||||
Platform: grp.Platform,
|
||||
}
|
||||
}
|
||||
g := group[grp.ID]
|
||||
if g.GroupName == "" && grp.Name != "" {
|
||||
g.GroupName = grp.Name
|
||||
}
|
||||
if g.Platform != "" && grp.Platform != "" && g.Platform != grp.Platform {
|
||||
// Groups are expected to be platform-scoped. If mismatch is observed, avoid misleading labels.
|
||||
g.Platform = ""
|
||||
}
|
||||
g.MaxCapacity += int64(acc.Concurrency)
|
||||
g.CurrentInUse += currentInUse
|
||||
g.WaitingInQueue += waiting
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, info := range platform {
|
||||
if info.MaxCapacity > 0 {
|
||||
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
|
||||
}
|
||||
}
|
||||
for _, info := range group {
|
||||
if info.MaxCapacity > 0 {
|
||||
info.LoadPercentage = float64(info.CurrentInUse) / float64(info.MaxCapacity) * 100
|
||||
}
|
||||
}
|
||||
|
||||
return platform, group, account, &collectedAt, nil
|
||||
}
|
||||
90
backend/internal/service/ops_dashboard.go
Normal file
90
backend/internal/service/ops_dashboard.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
|
||||
// Resolve query mode (requested via query param, or DB default).
|
||||
filter.QueryMode = s.resolveOpsQueryMode(ctx, filter.QueryMode)
|
||||
|
||||
overview, err := s.opsRepo.GetDashboardOverview(ctx, filter)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrOpsPreaggregatedNotPopulated) {
|
||||
return nil, infraerrors.Conflict("OPS_PREAGG_NOT_READY", "Pre-aggregated ops metrics are not populated yet")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Best-effort system health + jobs; dashboard metrics should still render if these are missing.
|
||||
if metrics, err := s.opsRepo.GetLatestSystemMetrics(ctx, 1); err == nil {
|
||||
// Attach config-derived limits so the UI can show "current / max" for connection pools.
|
||||
// These are best-effort and should never block the dashboard rendering.
|
||||
if s != nil && s.cfg != nil {
|
||||
if s.cfg.Database.MaxOpenConns > 0 {
|
||||
metrics.DBMaxOpenConns = intPtr(s.cfg.Database.MaxOpenConns)
|
||||
}
|
||||
if s.cfg.Redis.PoolSize > 0 {
|
||||
metrics.RedisPoolSize = intPtr(s.cfg.Redis.PoolSize)
|
||||
}
|
||||
}
|
||||
overview.SystemMetrics = metrics
|
||||
} else if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
log.Printf("[Ops] GetLatestSystemMetrics failed: %v", err)
|
||||
}
|
||||
|
||||
if heartbeats, err := s.opsRepo.ListJobHeartbeats(ctx); err == nil {
|
||||
overview.JobHeartbeats = heartbeats
|
||||
} else {
|
||||
log.Printf("[Ops] ListJobHeartbeats failed: %v", err)
|
||||
}
|
||||
|
||||
overview.HealthScore = computeDashboardHealthScore(time.Now().UTC(), overview)
|
||||
|
||||
return overview, nil
|
||||
}
|
||||
|
||||
func (s *OpsService) resolveOpsQueryMode(ctx context.Context, requested OpsQueryMode) OpsQueryMode {
|
||||
if requested.IsValid() {
|
||||
// Allow "auto" to be disabled via config until preagg is proven stable in production.
|
||||
// Forced `preagg` via query param still works.
|
||||
if requested == OpsQueryModeAuto && s != nil && s.cfg != nil && !s.cfg.Ops.UsePreaggregatedTables {
|
||||
return OpsQueryModeRaw
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
mode := OpsQueryModeAuto
|
||||
if s != nil && s.settingRepo != nil {
|
||||
if raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsQueryModeDefault); err == nil {
|
||||
mode = ParseOpsQueryMode(raw)
|
||||
}
|
||||
}
|
||||
|
||||
if mode == OpsQueryModeAuto && s != nil && s.cfg != nil && !s.cfg.Ops.UsePreaggregatedTables {
|
||||
return OpsQueryModeRaw
|
||||
}
|
||||
return mode
|
||||
}
|
||||
87
backend/internal/service/ops_dashboard_models.go
Normal file
87
backend/internal/service/ops_dashboard_models.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type OpsDashboardFilter struct {
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
|
||||
Platform string
|
||||
GroupID *int64
|
||||
|
||||
// QueryMode controls whether dashboard queries should use raw logs or pre-aggregated tables.
|
||||
// Expected values: auto/raw/preagg (see OpsQueryMode).
|
||||
QueryMode OpsQueryMode
|
||||
}
|
||||
|
||||
type OpsRateSummary struct {
|
||||
Current float64 `json:"current"`
|
||||
Peak float64 `json:"peak"`
|
||||
Avg float64 `json:"avg"`
|
||||
}
|
||||
|
||||
type OpsPercentiles struct {
|
||||
P50 *int `json:"p50_ms"`
|
||||
P90 *int `json:"p90_ms"`
|
||||
P95 *int `json:"p95_ms"`
|
||||
P99 *int `json:"p99_ms"`
|
||||
Avg *int `json:"avg_ms"`
|
||||
Max *int `json:"max_ms"`
|
||||
}
|
||||
|
||||
type OpsDashboardOverview struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
|
||||
// HealthScore is a backend-computed overall health score (0-100).
|
||||
// It is derived from the monitored metrics in this overview, plus best-effort system metrics/job heartbeats.
|
||||
HealthScore int `json:"health_score"`
|
||||
|
||||
// Latest system-level snapshot (window=1m, global).
|
||||
SystemMetrics *OpsSystemMetricsSnapshot `json:"system_metrics"`
|
||||
|
||||
// Background jobs health (heartbeats).
|
||||
JobHeartbeats []*OpsJobHeartbeat `json:"job_heartbeats"`
|
||||
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
ErrorCountTotal int64 `json:"error_count_total"`
|
||||
BusinessLimitedCount int64 `json:"business_limited_count"`
|
||||
|
||||
ErrorCountSLA int64 `json:"error_count_sla"`
|
||||
RequestCountTotal int64 `json:"request_count_total"`
|
||||
RequestCountSLA int64 `json:"request_count_sla"`
|
||||
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
|
||||
SLA float64 `json:"sla"`
|
||||
ErrorRate float64 `json:"error_rate"`
|
||||
UpstreamErrorRate float64 `json:"upstream_error_rate"`
|
||||
UpstreamErrorCountExcl429529 int64 `json:"upstream_error_count_excl_429_529"`
|
||||
Upstream429Count int64 `json:"upstream_429_count"`
|
||||
Upstream529Count int64 `json:"upstream_529_count"`
|
||||
|
||||
QPS OpsRateSummary `json:"qps"`
|
||||
TPS OpsRateSummary `json:"tps"`
|
||||
|
||||
Duration OpsPercentiles `json:"duration"`
|
||||
TTFT OpsPercentiles `json:"ttft"`
|
||||
}
|
||||
|
||||
type OpsLatencyHistogramBucket struct {
|
||||
Range string `json:"range"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
// OpsLatencyHistogramResponse is a coarse latency distribution histogram (success requests only).
|
||||
// It is used by the Ops dashboard to quickly identify tail latency regressions.
|
||||
type OpsLatencyHistogramResponse struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Platform string `json:"platform"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
Buckets []*OpsLatencyHistogramBucket `json:"buckets"`
|
||||
}
|
||||
45
backend/internal/service/ops_errors.go
Normal file
45
backend/internal/service/ops_errors.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetErrorTrend(ctx, filter, bucketSeconds)
|
||||
}
|
||||
|
||||
func (s *OpsService) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetErrorDistribution(ctx, filter)
|
||||
}
|
||||
143
backend/internal/service/ops_health_score.go
Normal file
143
backend/internal/service/ops_health_score.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// computeDashboardHealthScore computes a 0-100 health score from the metrics returned by the dashboard overview.
|
||||
//
|
||||
// Design goals:
|
||||
// - Backend-owned scoring (UI only displays).
|
||||
// - Layered scoring: Business Health (70%) + Infrastructure Health (30%)
|
||||
// - Avoids double-counting (e.g., DB failure affects both infra and business metrics)
|
||||
// - Conservative + stable: penalize clear degradations; avoid overreacting to missing/idle data.
|
||||
func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview) int {
|
||||
if overview == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Idle/no-data: avoid showing a "bad" score when there is no traffic.
|
||||
// UI can still render a gray/idle state based on QPS + error rate.
|
||||
if overview.RequestCountSLA <= 0 && overview.RequestCountTotal <= 0 && overview.ErrorCountTotal <= 0 {
|
||||
return 100
|
||||
}
|
||||
|
||||
businessHealth := computeBusinessHealth(overview)
|
||||
infraHealth := computeInfraHealth(now, overview)
|
||||
|
||||
// Weighted combination: 70% business + 30% infrastructure
|
||||
score := businessHealth*0.7 + infraHealth*0.3
|
||||
return int(math.Round(clampFloat64(score, 0, 100)))
|
||||
}
|
||||
|
||||
// computeBusinessHealth calculates business health score (0-100)
|
||||
// Components: Error Rate (50%) + TTFT (50%)
|
||||
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
|
||||
// Error rate score: 1% → 100, 10% → 0 (linear)
|
||||
// Combines request errors and upstream errors
|
||||
errorScore := 100.0
|
||||
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
|
||||
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
|
||||
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
|
||||
if combinedErrorPct > 1.0 {
|
||||
if combinedErrorPct <= 10.0 {
|
||||
errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
|
||||
} else {
|
||||
errorScore = 0
|
||||
}
|
||||
}
|
||||
|
||||
// TTFT score: 1s → 100, 3s → 0 (linear)
|
||||
// Time to first token is critical for user experience
|
||||
ttftScore := 100.0
|
||||
if overview.TTFT.P99 != nil {
|
||||
p99 := float64(*overview.TTFT.P99)
|
||||
if p99 > 1000 {
|
||||
if p99 <= 3000 {
|
||||
ttftScore = (3000 - p99) / 2000 * 100
|
||||
} else {
|
||||
ttftScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Weighted combination: 50% error rate + 50% TTFT
|
||||
return errorScore*0.5 + ttftScore*0.5
|
||||
}
|
||||
|
||||
// computeInfraHealth calculates infrastructure health score (0-100)
|
||||
// Components: Storage (40%) + Compute Resources (30%) + Background Jobs (30%)
|
||||
func computeInfraHealth(now time.Time, overview *OpsDashboardOverview) float64 {
|
||||
// Storage score: DB critical, Redis less critical
|
||||
storageScore := 100.0
|
||||
if overview.SystemMetrics != nil {
|
||||
if overview.SystemMetrics.DBOK != nil && !*overview.SystemMetrics.DBOK {
|
||||
storageScore = 0 // DB failure is critical
|
||||
} else if overview.SystemMetrics.RedisOK != nil && !*overview.SystemMetrics.RedisOK {
|
||||
storageScore = 50 // Redis failure is degraded but not critical
|
||||
}
|
||||
}
|
||||
|
||||
// Compute resources score: CPU + Memory
|
||||
computeScore := 100.0
|
||||
if overview.SystemMetrics != nil {
|
||||
cpuScore := 100.0
|
||||
if overview.SystemMetrics.CPUUsagePercent != nil {
|
||||
cpuPct := clampFloat64(*overview.SystemMetrics.CPUUsagePercent, 0, 100)
|
||||
if cpuPct > 80 {
|
||||
if cpuPct <= 100 {
|
||||
cpuScore = (100 - cpuPct) / 20 * 100
|
||||
} else {
|
||||
cpuScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
memScore := 100.0
|
||||
if overview.SystemMetrics.MemoryUsagePercent != nil {
|
||||
memPct := clampFloat64(*overview.SystemMetrics.MemoryUsagePercent, 0, 100)
|
||||
if memPct > 85 {
|
||||
if memPct <= 100 {
|
||||
memScore = (100 - memPct) / 15 * 100
|
||||
} else {
|
||||
memScore = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
computeScore = (cpuScore + memScore) / 2
|
||||
}
|
||||
|
||||
// Background jobs score
|
||||
jobScore := 100.0
|
||||
failedJobs := 0
|
||||
totalJobs := 0
|
||||
for _, hb := range overview.JobHeartbeats {
|
||||
if hb == nil {
|
||||
continue
|
||||
}
|
||||
totalJobs++
|
||||
if hb.LastErrorAt != nil && (hb.LastSuccessAt == nil || hb.LastErrorAt.After(*hb.LastSuccessAt)) {
|
||||
failedJobs++
|
||||
} else if hb.LastSuccessAt != nil && now.Sub(*hb.LastSuccessAt) > 15*time.Minute {
|
||||
failedJobs++
|
||||
}
|
||||
}
|
||||
if totalJobs > 0 && failedJobs > 0 {
|
||||
jobScore = (1 - float64(failedJobs)/float64(totalJobs)) * 100
|
||||
}
|
||||
|
||||
// Weighted combination
|
||||
return storageScore*0.4 + computeScore*0.3 + jobScore*0.3
|
||||
}
|
||||
|
||||
func clampFloat64(v float64, min float64, max float64) float64 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
442
backend/internal/service/ops_health_score_test.go
Normal file
442
backend/internal/service/ops_health_score_test.go
Normal file
@@ -0,0 +1,442 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestComputeDashboardHealthScore_IdleReturns100(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
score := computeDashboardHealthScore(time.Now().UTC(), &OpsDashboardOverview{})
|
||||
require.Equal(t, 100, score)
|
||||
}
|
||||
|
||||
func TestComputeDashboardHealthScore_DegradesOnBadSignals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ov := &OpsDashboardOverview{
|
||||
RequestCountTotal: 100,
|
||||
RequestCountSLA: 100,
|
||||
SuccessCount: 90,
|
||||
ErrorCountTotal: 10,
|
||||
ErrorCountSLA: 10,
|
||||
|
||||
SLA: 0.90,
|
||||
ErrorRate: 0.10,
|
||||
UpstreamErrorRate: 0.08,
|
||||
|
||||
Duration: OpsPercentiles{P99: intPtr(20_000)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(2_000)},
|
||||
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(98.0),
|
||||
MemoryUsagePercent: float64Ptr(97.0),
|
||||
DBConnWaiting: intPtr(3),
|
||||
ConcurrencyQueueDepth: intPtr(10),
|
||||
},
|
||||
JobHeartbeats: []*OpsJobHeartbeat{
|
||||
{
|
||||
JobName: "job-a",
|
||||
LastErrorAt: timePtr(time.Now().UTC().Add(-1 * time.Minute)),
|
||||
LastError: stringPtr("boom"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
score := computeDashboardHealthScore(time.Now().UTC(), ov)
|
||||
require.Less(t, score, 80)
|
||||
require.GreaterOrEqual(t, score, 0)
|
||||
}
|
||||
|
||||
func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
overview *OpsDashboardOverview
|
||||
wantMin int
|
||||
wantMax int
|
||||
}{
|
||||
{
|
||||
name: "nil overview returns 0",
|
||||
overview: nil,
|
||||
wantMin: 0,
|
||||
wantMax: 0,
|
||||
},
|
||||
{
|
||||
name: "perfect health",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 1.0,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(100)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "good health - SLA 99.8%",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.998,
|
||||
ErrorRate: 0.003,
|
||||
UpstreamErrorRate: 0.001,
|
||||
Duration: OpsPercentiles{P99: intPtr(800)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(200)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(50),
|
||||
MemoryUsagePercent: float64Ptr(60),
|
||||
},
|
||||
},
|
||||
wantMin: 95,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "medium health - SLA 96%",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.96,
|
||||
ErrorRate: 0.02,
|
||||
UpstreamErrorRate: 0.01,
|
||||
Duration: OpsPercentiles{P99: intPtr(3000)},
|
||||
TTFT: OpsPercentiles{P99: intPtr(600)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(70),
|
||||
MemoryUsagePercent: float64Ptr(75),
|
||||
},
|
||||
},
|
||||
wantMin: 96,
|
||||
wantMax: 97,
|
||||
},
|
||||
{
|
||||
name: "DB failure",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 70,
|
||||
wantMax: 90,
|
||||
},
|
||||
{
|
||||
name: "Redis failure",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 85,
|
||||
wantMax: 95,
|
||||
},
|
||||
{
|
||||
name: "high CPU usage",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(95),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 85,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "combined failures - business degraded + infra healthy",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.90,
|
||||
ErrorRate: 0.05,
|
||||
UpstreamErrorRate: 0.02,
|
||||
Duration: OpsPercentiles{P99: intPtr(10000)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(20),
|
||||
MemoryUsagePercent: float64Ptr(30),
|
||||
},
|
||||
},
|
||||
wantMin: 84,
|
||||
wantMax: 85,
|
||||
},
|
||||
{
|
||||
name: "combined failures - business healthy + infra degraded",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
RequestCountSLA: 1000,
|
||||
SLA: 0.998,
|
||||
ErrorRate: 0.001,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(600)},
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(95),
|
||||
MemoryUsagePercent: float64Ptr(95),
|
||||
},
|
||||
},
|
||||
wantMin: 70,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
score := computeDashboardHealthScore(time.Now().UTC(), tt.overview)
|
||||
require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %d", tt.wantMin)
|
||||
require.LessOrEqual(t, score, tt.wantMax, "score should be <= %d", tt.wantMax)
|
||||
require.GreaterOrEqual(t, score, 0, "score must be >= 0")
|
||||
require.LessOrEqual(t, score, 100, "score must be <= 100")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeBusinessHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
overview *OpsDashboardOverview
|
||||
wantMin float64
|
||||
wantMax float64
|
||||
}{
|
||||
{
|
||||
name: "perfect metrics",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 1.0,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "SLA boundary 99.5%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "SLA boundary 95%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.95,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "error rate boundary 1%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.99,
|
||||
ErrorRate: 0.01,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "error rate 5%",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.95,
|
||||
ErrorRate: 0.05,
|
||||
UpstreamErrorRate: 0,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 77,
|
||||
wantMax: 78,
|
||||
},
|
||||
{
|
||||
name: "TTFT boundary 2s",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.99,
|
||||
ErrorRate: 0,
|
||||
UpstreamErrorRate: 0,
|
||||
TTFT: OpsPercentiles{P99: intPtr(2000)},
|
||||
},
|
||||
wantMin: 75,
|
||||
wantMax: 75,
|
||||
},
|
||||
{
|
||||
name: "upstream error dominates",
|
||||
overview: &OpsDashboardOverview{
|
||||
SLA: 0.995,
|
||||
ErrorRate: 0.001,
|
||||
UpstreamErrorRate: 0.03,
|
||||
Duration: OpsPercentiles{P99: intPtr(500)},
|
||||
},
|
||||
wantMin: 88,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
score := computeBusinessHealth(tt.overview)
|
||||
require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %.1f", tt.wantMin)
|
||||
require.LessOrEqual(t, score, tt.wantMax, "score should be <= %.1f", tt.wantMax)
|
||||
require.GreaterOrEqual(t, score, 0.0, "score must be >= 0")
|
||||
require.LessOrEqual(t, score, 100.0, "score must be <= 100")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeInfraHealth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
overview *OpsDashboardOverview
|
||||
wantMin float64
|
||||
wantMax float64
|
||||
}{
|
||||
{
|
||||
name: "all infrastructure healthy",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 100,
|
||||
wantMax: 100,
|
||||
},
|
||||
{
|
||||
name: "DB down",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(false),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 50,
|
||||
wantMax: 70,
|
||||
},
|
||||
{
|
||||
name: "Redis down",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(false),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 80,
|
||||
wantMax: 95,
|
||||
},
|
||||
{
|
||||
name: "CPU at 90%",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(90),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
},
|
||||
wantMin: 85,
|
||||
wantMax: 95,
|
||||
},
|
||||
{
|
||||
name: "failed background job",
|
||||
overview: &OpsDashboardOverview{
|
||||
RequestCountTotal: 1000,
|
||||
SystemMetrics: &OpsSystemMetricsSnapshot{
|
||||
DBOK: boolPtr(true),
|
||||
RedisOK: boolPtr(true),
|
||||
CPUUsagePercent: float64Ptr(30),
|
||||
MemoryUsagePercent: float64Ptr(40),
|
||||
},
|
||||
JobHeartbeats: []*OpsJobHeartbeat{
|
||||
{
|
||||
JobName: "test-job",
|
||||
LastErrorAt: &now,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantMin: 70,
|
||||
wantMax: 90,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
score := computeInfraHealth(now, tt.overview)
|
||||
require.GreaterOrEqual(t, score, tt.wantMin, "score should be >= %.1f", tt.wantMin)
|
||||
require.LessOrEqual(t, score, tt.wantMax, "score should be <= %.1f", tt.wantMax)
|
||||
require.GreaterOrEqual(t, score, 0.0, "score must be >= 0")
|
||||
require.LessOrEqual(t, score, 100.0, "score must be <= 100")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func timePtr(v time.Time) *time.Time { return &v }
|
||||
|
||||
func stringPtr(v string) *string { return &v }
|
||||
26
backend/internal/service/ops_histograms.go
Normal file
26
backend/internal/service/ops_histograms.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *OpsService) GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) {
|
||||
if err := s.RequireMonitoringEnabled(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.opsRepo == nil {
|
||||
return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required")
|
||||
}
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be <= end_time")
|
||||
}
|
||||
return s.opsRepo.GetLatencyHistogram(ctx, filter)
|
||||
}
|
||||
920
backend/internal/service/ops_metrics_collector.go
Normal file
920
backend/internal/service/ops_metrics_collector.go
Normal file
@@ -0,0 +1,920 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/shirou/gopsutil/v4/cpu"
|
||||
"github.com/shirou/gopsutil/v4/mem"
|
||||
)
|
||||
|
||||
const (
|
||||
opsMetricsCollectorJobName = "ops_metrics_collector"
|
||||
opsMetricsCollectorMinInterval = 60 * time.Second
|
||||
opsMetricsCollectorMaxInterval = 1 * time.Hour
|
||||
|
||||
opsMetricsCollectorTimeout = 10 * time.Second
|
||||
|
||||
opsMetricsCollectorLeaderLockKey = "ops:metrics:collector:leader"
|
||||
opsMetricsCollectorLeaderLockTTL = 90 * time.Second
|
||||
|
||||
opsMetricsCollectorHeartbeatTimeout = 2 * time.Second
|
||||
|
||||
bytesPerMB = 1024 * 1024
|
||||
)
|
||||
|
||||
var opsMetricsCollectorAdvisoryLockID = hashAdvisoryLockID(opsMetricsCollectorLeaderLockKey)
|
||||
|
||||
type OpsMetricsCollector struct {
|
||||
opsRepo OpsRepository
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
|
||||
accountRepo AccountRepository
|
||||
concurrencyService *ConcurrencyService
|
||||
|
||||
db *sql.DB
|
||||
redisClient *redis.Client
|
||||
instanceID string
|
||||
|
||||
lastCgroupCPUUsageNanos uint64
|
||||
lastCgroupCPUSampleAt time.Time
|
||||
|
||||
stopCh chan struct{}
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
skipLogMu sync.Mutex
|
||||
skipLogAt time.Time
|
||||
}
|
||||
|
||||
func NewOpsMetricsCollector(
|
||||
opsRepo OpsRepository,
|
||||
settingRepo SettingRepository,
|
||||
accountRepo AccountRepository,
|
||||
concurrencyService *ConcurrencyService,
|
||||
db *sql.DB,
|
||||
redisClient *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *OpsMetricsCollector {
|
||||
return &OpsMetricsCollector{
|
||||
opsRepo: opsRepo,
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
accountRepo: accountRepo,
|
||||
concurrencyService: concurrencyService,
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
instanceID: uuid.NewString(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) Start() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.startOnce.Do(func() {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
}
|
||||
go c.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) Stop() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.stopOnce.Do(func() {
|
||||
if c.stopCh != nil {
|
||||
close(c.stopCh)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) run() {
|
||||
// First run immediately so the dashboard has data soon after startup.
|
||||
c.collectOnce()
|
||||
|
||||
for {
|
||||
interval := c.getInterval()
|
||||
timer := time.NewTimer(interval)
|
||||
select {
|
||||
case <-timer.C:
|
||||
c.collectOnce()
|
||||
case <-c.stopCh:
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) getInterval() time.Duration {
|
||||
interval := opsMetricsCollectorMinInterval
|
||||
|
||||
if c.settingRepo == nil {
|
||||
return interval
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
raw, err := c.settingRepo.GetValue(ctx, SettingKeyOpsMetricsIntervalSeconds)
|
||||
if err != nil {
|
||||
return interval
|
||||
}
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return interval
|
||||
}
|
||||
|
||||
seconds, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return interval
|
||||
}
|
||||
if seconds < int(opsMetricsCollectorMinInterval.Seconds()) {
|
||||
seconds = int(opsMetricsCollectorMinInterval.Seconds())
|
||||
}
|
||||
if seconds > int(opsMetricsCollectorMaxInterval.Seconds()) {
|
||||
seconds = int(opsMetricsCollectorMaxInterval.Seconds())
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectOnce() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if c.cfg != nil && !c.cfg.Ops.Enabled {
|
||||
return
|
||||
}
|
||||
if c.opsRepo == nil {
|
||||
return
|
||||
}
|
||||
if c.db == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), opsMetricsCollectorTimeout)
|
||||
defer cancel()
|
||||
|
||||
if !c.isMonitoringEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
|
||||
release, ok := c.tryAcquireLeaderLock(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if release != nil {
|
||||
defer release()
|
||||
}
|
||||
|
||||
startedAt := time.Now().UTC()
|
||||
err := c.collectAndPersist(ctx)
|
||||
finishedAt := time.Now().UTC()
|
||||
|
||||
durationMs := finishedAt.Sub(startedAt).Milliseconds()
|
||||
dur := durationMs
|
||||
runAt := startedAt
|
||||
|
||||
if err != nil {
|
||||
msg := truncateString(err.Error(), 2048)
|
||||
errAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), opsMetricsCollectorHeartbeatTimeout)
|
||||
defer hbCancel()
|
||||
_ = c.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsMetricsCollectorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastErrorAt: &errAt,
|
||||
LastError: &msg,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
log.Printf("[OpsMetricsCollector] collect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
successAt := finishedAt
|
||||
hbCtx, hbCancel := context.WithTimeout(context.Background(), opsMetricsCollectorHeartbeatTimeout)
|
||||
defer hbCancel()
|
||||
_ = c.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
|
||||
JobName: opsMetricsCollectorJobName,
|
||||
LastRunAt: &runAt,
|
||||
LastSuccessAt: &successAt,
|
||||
LastDurationMs: &dur,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) isMonitoringEnabled(ctx context.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
if c.cfg != nil && !c.cfg.Ops.Enabled {
|
||||
return false
|
||||
}
|
||||
if c.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
value, err := c.settingRepo.GetValue(ctx, SettingKeyOpsMonitoringEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
// Fail-open: collector should not become a hard dependency.
|
||||
return true
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Align to stable minute boundaries to avoid partial buckets and to maximize cache hits.
|
||||
now := time.Now().UTC()
|
||||
windowEnd := now.Truncate(time.Minute)
|
||||
windowStart := windowEnd.Add(-1 * time.Minute)
|
||||
|
||||
sys, err := c.collectSystemStats(ctx)
|
||||
if err != nil {
|
||||
// Continue; system stats are best-effort.
|
||||
log.Printf("[OpsMetricsCollector] system stats error: %v", err)
|
||||
}
|
||||
|
||||
dbOK := c.checkDB(ctx)
|
||||
redisOK := c.checkRedis(ctx)
|
||||
active, idle := c.dbPoolStats()
|
||||
redisTotal, redisIdle, redisStatsOK := c.redisPoolStats()
|
||||
|
||||
successCount, tokenConsumed, err := c.queryUsageCounts(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query usage counts: %w", err)
|
||||
}
|
||||
|
||||
duration, ttft, err := c.queryUsageLatency(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query usage latency: %w", err)
|
||||
}
|
||||
|
||||
errorTotal, businessLimited, errorSLA, upstreamExcl, upstream429, upstream529, err := c.queryErrorCounts(ctx, windowStart, windowEnd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query error counts: %w", err)
|
||||
}
|
||||
|
||||
windowSeconds := windowEnd.Sub(windowStart).Seconds()
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 60
|
||||
}
|
||||
requestTotal := successCount + errorTotal
|
||||
qps := float64(requestTotal) / windowSeconds
|
||||
tps := float64(tokenConsumed) / windowSeconds
|
||||
|
||||
goroutines := runtime.NumGoroutine()
|
||||
concurrencyQueueDepth := c.collectConcurrencyQueueDepth(ctx)
|
||||
|
||||
input := &OpsInsertSystemMetricsInput{
|
||||
CreatedAt: windowEnd,
|
||||
WindowMinutes: 1,
|
||||
|
||||
SuccessCount: successCount,
|
||||
ErrorCountTotal: errorTotal,
|
||||
BusinessLimitedCount: businessLimited,
|
||||
ErrorCountSLA: errorSLA,
|
||||
|
||||
UpstreamErrorCountExcl429529: upstreamExcl,
|
||||
Upstream429Count: upstream429,
|
||||
Upstream529Count: upstream529,
|
||||
|
||||
TokenConsumed: tokenConsumed,
|
||||
QPS: float64Ptr(roundTo1DP(qps)),
|
||||
TPS: float64Ptr(roundTo1DP(tps)),
|
||||
|
||||
DurationP50Ms: duration.p50,
|
||||
DurationP90Ms: duration.p90,
|
||||
DurationP95Ms: duration.p95,
|
||||
DurationP99Ms: duration.p99,
|
||||
DurationAvgMs: duration.avg,
|
||||
DurationMaxMs: duration.max,
|
||||
|
||||
TTFTP50Ms: ttft.p50,
|
||||
TTFTP90Ms: ttft.p90,
|
||||
TTFTP95Ms: ttft.p95,
|
||||
TTFTP99Ms: ttft.p99,
|
||||
TTFTAvgMs: ttft.avg,
|
||||
TTFTMaxMs: ttft.max,
|
||||
|
||||
CPUUsagePercent: sys.cpuUsagePercent,
|
||||
MemoryUsedMB: sys.memoryUsedMB,
|
||||
MemoryTotalMB: sys.memoryTotalMB,
|
||||
MemoryUsagePercent: sys.memoryUsagePercent,
|
||||
|
||||
DBOK: boolPtr(dbOK),
|
||||
RedisOK: boolPtr(redisOK),
|
||||
|
||||
RedisConnTotal: func() *int {
|
||||
if !redisStatsOK {
|
||||
return nil
|
||||
}
|
||||
return intPtr(redisTotal)
|
||||
}(),
|
||||
RedisConnIdle: func() *int {
|
||||
if !redisStatsOK {
|
||||
return nil
|
||||
}
|
||||
return intPtr(redisIdle)
|
||||
}(),
|
||||
|
||||
DBConnActive: intPtr(active),
|
||||
DBConnIdle: intPtr(idle),
|
||||
GoroutineCount: intPtr(goroutines),
|
||||
ConcurrencyQueueDepth: concurrencyQueueDepth,
|
||||
}
|
||||
|
||||
return c.opsRepo.InsertSystemMetrics(ctx, input)
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectConcurrencyQueueDepth(parentCtx context.Context) *int {
|
||||
if c == nil || c.accountRepo == nil || c.concurrencyService == nil {
|
||||
return nil
|
||||
}
|
||||
if parentCtx == nil {
|
||||
parentCtx = context.Background()
|
||||
}
|
||||
|
||||
// Best-effort: never let concurrency sampling break the metrics collector.
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
accounts, err := c.accountRepo.ListSchedulable(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
zero := 0
|
||||
return &zero
|
||||
}
|
||||
|
||||
batch := make([]AccountWithConcurrency, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
maxConc := acc.Concurrency
|
||||
if maxConc < 0 {
|
||||
maxConc = 0
|
||||
}
|
||||
batch = append(batch, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: maxConc,
|
||||
})
|
||||
}
|
||||
if len(batch) == 0 {
|
||||
zero := 0
|
||||
return &zero
|
||||
}
|
||||
|
||||
loadMap, err := c.concurrencyService.GetAccountsLoadBatch(ctx, batch)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var total int64
|
||||
for _, info := range loadMap {
|
||||
if info == nil || info.WaitingCount <= 0 {
|
||||
continue
|
||||
}
|
||||
total += int64(info.WaitingCount)
|
||||
}
|
||||
if total < 0 {
|
||||
total = 0
|
||||
}
|
||||
|
||||
maxInt := int64(^uint(0) >> 1)
|
||||
if total > maxInt {
|
||||
total = maxInt
|
||||
}
|
||||
v := int(total)
|
||||
return &v
|
||||
}
|
||||
|
||||
type opsCollectedPercentiles struct {
|
||||
p50 *int
|
||||
p90 *int
|
||||
p95 *int
|
||||
p99 *int
|
||||
avg *float64
|
||||
max *int
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryUsageCounts(ctx context.Context, start, end time.Time) (successCount int64, tokenConsumed int64, err error) {
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(COUNT(*), 0) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2`
|
||||
|
||||
var tokens sql.NullInt64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&successCount, &tokens); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
return successCount, tokenConsumed, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryUsageLatency(ctx context.Context, start, end time.Time) (duration opsCollectedPercentiles, ttft opsCollectedPercentiles, err error) {
|
||||
{
|
||||
q := `
|
||||
SELECT
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99,
|
||||
AVG(duration_ms) AS avg_ms,
|
||||
MAX(duration_ms) AS max_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND duration_ms IS NOT NULL`
|
||||
|
||||
var p50, p90, p95, p99 sql.NullFloat64
|
||||
var avg sql.NullFloat64
|
||||
var max sql.NullInt64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
|
||||
return opsCollectedPercentiles{}, opsCollectedPercentiles{}, err
|
||||
}
|
||||
duration.p50 = floatToIntPtr(p50)
|
||||
duration.p90 = floatToIntPtr(p90)
|
||||
duration.p95 = floatToIntPtr(p95)
|
||||
duration.p99 = floatToIntPtr(p99)
|
||||
if avg.Valid {
|
||||
v := roundTo1DP(avg.Float64)
|
||||
duration.avg = &v
|
||||
}
|
||||
if max.Valid {
|
||||
v := int(max.Int64)
|
||||
duration.max = &v
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
q := `
|
||||
SELECT
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99,
|
||||
AVG(first_token_ms) AS avg_ms,
|
||||
MAX(first_token_ms) AS max_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND first_token_ms IS NOT NULL`
|
||||
|
||||
var p50, p90, p95, p99 sql.NullFloat64
|
||||
var avg sql.NullFloat64
|
||||
var max sql.NullInt64
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
|
||||
return opsCollectedPercentiles{}, opsCollectedPercentiles{}, err
|
||||
}
|
||||
ttft.p50 = floatToIntPtr(p50)
|
||||
ttft.p90 = floatToIntPtr(p90)
|
||||
ttft.p95 = floatToIntPtr(p95)
|
||||
ttft.p99 = floatToIntPtr(p99)
|
||||
if avg.Valid {
|
||||
v := roundTo1DP(avg.Float64)
|
||||
ttft.avg = &v
|
||||
}
|
||||
if max.Valid {
|
||||
v := int(max.Int64)
|
||||
ttft.max = &v
|
||||
}
|
||||
}
|
||||
|
||||
return duration, ttft, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) queryErrorCounts(ctx context.Context, start, end time.Time) (
|
||||
errorTotal int64,
|
||||
businessLimited int64,
|
||||
errorSLA int64,
|
||||
upstreamExcl429529 int64,
|
||||
upstream429 int64,
|
||||
upstream529 int64,
|
||||
err error,
|
||||
) {
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400), 0) AS error_total,
|
||||
COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited), 0) AS business_limited,
|
||||
COALESCE(COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited), 0) AS error_sla,
|
||||
COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)), 0) AS upstream_excl,
|
||||
COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429), 0) AS upstream_429,
|
||||
COALESCE(COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529), 0) AS upstream_529
|
||||
FROM ops_error_logs
|
||||
WHERE created_at >= $1 AND created_at < $2`
|
||||
|
||||
if err := c.db.QueryRowContext(ctx, q, start, end).Scan(
|
||||
&errorTotal,
|
||||
&businessLimited,
|
||||
&errorSLA,
|
||||
&upstreamExcl429529,
|
||||
&upstream429,
|
||||
&upstream529,
|
||||
); err != nil {
|
||||
return 0, 0, 0, 0, 0, 0, err
|
||||
}
|
||||
return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil
|
||||
}
|
||||
|
||||
type opsCollectedSystemStats struct {
|
||||
cpuUsagePercent *float64
|
||||
memoryUsedMB *int64
|
||||
memoryTotalMB *int64
|
||||
memoryUsagePercent *float64
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) collectSystemStats(ctx context.Context) (*opsCollectedSystemStats, error) {
|
||||
out := &opsCollectedSystemStats{}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
sampleAt := time.Now().UTC()
|
||||
|
||||
// Prefer cgroup (container) metrics when available.
|
||||
if cpuPct := c.tryCgroupCPUPercent(sampleAt); cpuPct != nil {
|
||||
out.cpuUsagePercent = cpuPct
|
||||
}
|
||||
|
||||
cgroupUsed, cgroupTotal, cgroupOK := readCgroupMemoryBytes()
|
||||
if cgroupOK {
|
||||
usedMB := int64(cgroupUsed / bytesPerMB)
|
||||
out.memoryUsedMB = &usedMB
|
||||
if cgroupTotal > 0 {
|
||||
totalMB := int64(cgroupTotal / bytesPerMB)
|
||||
out.memoryTotalMB = &totalMB
|
||||
pct := roundTo1DP(float64(cgroupUsed) / float64(cgroupTotal) * 100)
|
||||
out.memoryUsagePercent = &pct
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to host metrics if cgroup metrics are unavailable (or incomplete).
|
||||
if out.cpuUsagePercent == nil {
|
||||
if cpuPercents, err := cpu.PercentWithContext(ctx, 0, false); err == nil && len(cpuPercents) > 0 {
|
||||
v := roundTo1DP(cpuPercents[0])
|
||||
out.cpuUsagePercent = &v
|
||||
}
|
||||
}
|
||||
|
||||
// If total memory isn't available from cgroup (e.g. memory.max = "max"), fill total from host.
|
||||
if out.memoryUsedMB == nil || out.memoryTotalMB == nil || out.memoryUsagePercent == nil {
|
||||
if vm, err := mem.VirtualMemoryWithContext(ctx); err == nil && vm != nil {
|
||||
if out.memoryUsedMB == nil {
|
||||
usedMB := int64(vm.Used / bytesPerMB)
|
||||
out.memoryUsedMB = &usedMB
|
||||
}
|
||||
if out.memoryTotalMB == nil {
|
||||
totalMB := int64(vm.Total / bytesPerMB)
|
||||
out.memoryTotalMB = &totalMB
|
||||
}
|
||||
if out.memoryUsagePercent == nil {
|
||||
if out.memoryUsedMB != nil && out.memoryTotalMB != nil && *out.memoryTotalMB > 0 {
|
||||
pct := roundTo1DP(float64(*out.memoryUsedMB) / float64(*out.memoryTotalMB) * 100)
|
||||
out.memoryUsagePercent = &pct
|
||||
} else {
|
||||
pct := roundTo1DP(vm.UsedPercent)
|
||||
out.memoryUsagePercent = &pct
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) tryCgroupCPUPercent(now time.Time) *float64 {
|
||||
usageNanos, ok := readCgroupCPUUsageNanos()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize baseline sample.
|
||||
if c.lastCgroupCPUSampleAt.IsZero() {
|
||||
c.lastCgroupCPUUsageNanos = usageNanos
|
||||
c.lastCgroupCPUSampleAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
elapsed := now.Sub(c.lastCgroupCPUSampleAt)
|
||||
if elapsed <= 0 {
|
||||
c.lastCgroupCPUUsageNanos = usageNanos
|
||||
c.lastCgroupCPUSampleAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
prev := c.lastCgroupCPUUsageNanos
|
||||
c.lastCgroupCPUUsageNanos = usageNanos
|
||||
c.lastCgroupCPUSampleAt = now
|
||||
|
||||
if usageNanos < prev {
|
||||
// Counter reset (container restarted).
|
||||
return nil
|
||||
}
|
||||
|
||||
deltaUsageSec := float64(usageNanos-prev) / 1e9
|
||||
elapsedSec := elapsed.Seconds()
|
||||
if elapsedSec <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cores := readCgroupCPULimitCores()
|
||||
if cores <= 0 {
|
||||
// Can't reliably normalize; skip and fall back to gopsutil.
|
||||
return nil
|
||||
}
|
||||
|
||||
pct := (deltaUsageSec / (elapsedSec * cores)) * 100
|
||||
if pct < 0 {
|
||||
pct = 0
|
||||
}
|
||||
// Clamp to avoid noise/jitter showing impossible values.
|
||||
if pct > 100 {
|
||||
pct = 100
|
||||
}
|
||||
v := roundTo1DP(pct)
|
||||
return &v
|
||||
}
|
||||
|
||||
func readCgroupMemoryBytes() (usedBytes uint64, totalBytes uint64, ok bool) {
|
||||
// cgroup v2 (most common in modern containers)
|
||||
if used, ok1 := readUintFile("/sys/fs/cgroup/memory.current"); ok1 {
|
||||
usedBytes = used
|
||||
rawMax, err := os.ReadFile("/sys/fs/cgroup/memory.max")
|
||||
if err == nil {
|
||||
s := strings.TrimSpace(string(rawMax))
|
||||
if s != "" && s != "max" {
|
||||
if v, err := strconv.ParseUint(s, 10, 64); err == nil {
|
||||
totalBytes = v
|
||||
}
|
||||
}
|
||||
}
|
||||
return usedBytes, totalBytes, true
|
||||
}
|
||||
|
||||
// cgroup v1 fallback
|
||||
if used, ok1 := readUintFile("/sys/fs/cgroup/memory/memory.usage_in_bytes"); ok1 {
|
||||
usedBytes = used
|
||||
if limit, ok2 := readUintFile("/sys/fs/cgroup/memory/memory.limit_in_bytes"); ok2 {
|
||||
// Some environments report a very large number when unlimited.
|
||||
if limit > 0 && limit < (1<<60) {
|
||||
totalBytes = limit
|
||||
}
|
||||
}
|
||||
return usedBytes, totalBytes, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func readCgroupCPUUsageNanos() (usageNanos uint64, ok bool) {
|
||||
// cgroup v2: cpu.stat has usage_usec
|
||||
if raw, err := os.ReadFile("/sys/fs/cgroup/cpu.stat"); err == nil {
|
||||
lines := strings.Split(string(raw), "\n")
|
||||
for _, line := range lines {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 {
|
||||
continue
|
||||
}
|
||||
if fields[0] != "usage_usec" {
|
||||
continue
|
||||
}
|
||||
v, err := strconv.ParseUint(fields[1], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
return v * 1000, true
|
||||
}
|
||||
}
|
||||
|
||||
// cgroup v1: cpuacct.usage is in nanoseconds
|
||||
if v, ok := readUintFile("/sys/fs/cgroup/cpuacct/cpuacct.usage"); ok {
|
||||
return v, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func readCgroupCPULimitCores() float64 {
|
||||
// cgroup v2: cpu.max => "<quota> <period>" or "max <period>"
|
||||
if raw, err := os.ReadFile("/sys/fs/cgroup/cpu.max"); err == nil {
|
||||
fields := strings.Fields(string(raw))
|
||||
if len(fields) >= 2 && fields[0] != "max" {
|
||||
quota, err1 := strconv.ParseFloat(fields[0], 64)
|
||||
period, err2 := strconv.ParseFloat(fields[1], 64)
|
||||
if err1 == nil && err2 == nil && quota > 0 && period > 0 {
|
||||
return quota / period
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cgroup v1: cpu.cfs_quota_us / cpu.cfs_period_us
|
||||
quota, okQuota := readIntFile("/sys/fs/cgroup/cpu/cpu.cfs_quota_us")
|
||||
period, okPeriod := readIntFile("/sys/fs/cgroup/cpu/cpu.cfs_period_us")
|
||||
if okQuota && okPeriod && quota > 0 && period > 0 {
|
||||
return float64(quota) / float64(period)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func readUintFile(path string) (uint64, bool) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
s := strings.TrimSpace(string(raw))
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func readIntFile(path string) (int64, bool) {
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
s := strings.TrimSpace(string(raw))
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) checkDB(ctx context.Context) bool {
|
||||
if c == nil || c.db == nil {
|
||||
return false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
var one int
|
||||
if err := c.db.QueryRowContext(ctx, "SELECT 1").Scan(&one); err != nil {
|
||||
return false
|
||||
}
|
||||
return one == 1
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) checkRedis(ctx context.Context) bool {
|
||||
if c == nil || c.redisClient == nil {
|
||||
return false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return c.redisClient.Ping(ctx).Err() == nil
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) redisPoolStats() (total int, idle int, ok bool) {
|
||||
if c == nil || c.redisClient == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
stats := c.redisClient.PoolStats()
|
||||
if stats == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
return int(stats.TotalConns), int(stats.IdleConns), true
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) dbPoolStats() (active int, idle int) {
|
||||
if c == nil || c.db == nil {
|
||||
return 0, 0
|
||||
}
|
||||
stats := c.db.Stats()
|
||||
return stats.InUse, stats.Idle
|
||||
}
|
||||
|
||||
var opsMetricsCollectorReleaseScript = redis.NewScript(`
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 0
|
||||
`)
|
||||
|
||||
func (c *OpsMetricsCollector) tryAcquireLeaderLock(ctx context.Context) (func(), bool) {
|
||||
if c == nil || c.redisClient == nil {
|
||||
return nil, true
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
ok, err := c.redisClient.SetNX(ctx, opsMetricsCollectorLeaderLockKey, c.instanceID, opsMetricsCollectorLeaderLockTTL).Result()
|
||||
if err != nil {
|
||||
// Prefer fail-closed to avoid stampeding the database when Redis is flaky.
|
||||
// Fallback to a DB advisory lock when Redis is present but unavailable.
|
||||
release, ok := tryAcquireDBAdvisoryLock(ctx, c.db, opsMetricsCollectorAdvisoryLockID)
|
||||
if !ok {
|
||||
c.maybeLogSkip()
|
||||
return nil, false
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
if !ok {
|
||||
c.maybeLogSkip()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
release := func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_, _ = opsMetricsCollectorReleaseScript.Run(ctx, c.redisClient, []string{opsMetricsCollectorLeaderLockKey}, c.instanceID).Result()
|
||||
}
|
||||
return release, true
|
||||
}
|
||||
|
||||
func (c *OpsMetricsCollector) maybeLogSkip() {
|
||||
c.skipLogMu.Lock()
|
||||
defer c.skipLogMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if !c.skipLogAt.IsZero() && now.Sub(c.skipLogAt) < time.Minute {
|
||||
return
|
||||
}
|
||||
c.skipLogAt = now
|
||||
log.Printf("[OpsMetricsCollector] leader lock held by another instance; skipping")
|
||||
}
|
||||
|
||||
func floatToIntPtr(v sql.NullFloat64) *int {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
n := int(math.Round(v.Float64))
|
||||
return &n
|
||||
}
|
||||
|
||||
func roundTo1DP(v float64) float64 {
|
||||
return math.Round(v*10) / 10
|
||||
}
|
||||
|
||||
func truncateString(s string, max int) string {
|
||||
if max <= 0 {
|
||||
return ""
|
||||
}
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
cut := s[:max]
|
||||
for len(cut) > 0 && !utf8.ValidString(cut) {
|
||||
cut = cut[:len(cut)-1]
|
||||
}
|
||||
return cut
|
||||
}
|
||||
|
||||
func boolPtr(v bool) *bool {
|
||||
out := v
|
||||
return &out
|
||||
}
|
||||
|
||||
func intPtr(v int) *int {
|
||||
out := v
|
||||
return &out
|
||||
}
|
||||
|
||||
func float64Ptr(v float64) *float64 {
|
||||
out := v
|
||||
return &out
|
||||
}
|
||||
169
backend/internal/service/ops_models.go
Normal file
169
backend/internal/service/ops_models.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type OpsErrorLog struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Standardized classification
|
||||
// - phase: request|auth|routing|upstream|network|internal
|
||||
// - owner: client|provider|platform
|
||||
// - source: client_request|upstream_http|gateway
|
||||
Phase string `json:"phase"`
|
||||
Type string `json:"type"`
|
||||
|
||||
Owner string `json:"error_owner"`
|
||||
Source string `json:"error_source"`
|
||||
|
||||
Severity string `json:"severity"`
|
||||
|
||||
StatusCode int `json:"status_code"`
|
||||
Platform string `json:"platform"`
|
||||
Model string `json:"model"`
|
||||
|
||||
IsRetryable bool `json:"is_retryable"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
|
||||
Resolved bool `json:"resolved"`
|
||||
ResolvedAt *time.Time `json:"resolved_at"`
|
||||
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
|
||||
ResolvedByUserName string `json:"resolved_by_user_name"`
|
||||
ResolvedRetryID *int64 `json:"resolved_retry_id"`
|
||||
ResolvedStatusRaw string `json:"-"`
|
||||
|
||||
ClientRequestID string `json:"client_request_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Message string `json:"message"`
|
||||
|
||||
UserID *int64 `json:"user_id"`
|
||||
UserEmail string `json:"user_email"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
AccountName string `json:"account_name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
GroupName string `json:"group_name"`
|
||||
|
||||
ClientIP *string `json:"client_ip"`
|
||||
RequestPath string `json:"request_path"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OpsErrorLogDetail struct {
|
||||
OpsErrorLog
|
||||
|
||||
ErrorBody string `json:"error_body"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
|
||||
// Upstream context (optional)
|
||||
UpstreamStatusCode *int `json:"upstream_status_code,omitempty"`
|
||||
UpstreamErrorMessage string `json:"upstream_error_message,omitempty"`
|
||||
UpstreamErrorDetail string `json:"upstream_error_detail,omitempty"`
|
||||
UpstreamErrors string `json:"upstream_errors,omitempty"` // JSON array (string) for display/parsing
|
||||
|
||||
// Timings (optional)
|
||||
AuthLatencyMs *int64 `json:"auth_latency_ms"`
|
||||
RoutingLatencyMs *int64 `json:"routing_latency_ms"`
|
||||
UpstreamLatencyMs *int64 `json:"upstream_latency_ms"`
|
||||
ResponseLatencyMs *int64 `json:"response_latency_ms"`
|
||||
TimeToFirstTokenMs *int64 `json:"time_to_first_token_ms"`
|
||||
|
||||
// Retry context
|
||||
RequestBody string `json:"request_body"`
|
||||
RequestBodyTruncated bool `json:"request_body_truncated"`
|
||||
RequestBodyBytes *int `json:"request_body_bytes"`
|
||||
RequestHeaders string `json:"request_headers,omitempty"`
|
||||
|
||||
// vNext metric semantics
|
||||
IsBusinessLimited bool `json:"is_business_limited"`
|
||||
}
|
||||
|
||||
type OpsErrorLogFilter struct {
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
|
||||
Platform string
|
||||
GroupID *int64
|
||||
AccountID *int64
|
||||
|
||||
StatusCodes []int
|
||||
StatusCodesOther bool
|
||||
Phase string
|
||||
Owner string
|
||||
Source string
|
||||
Resolved *bool
|
||||
Query string
|
||||
UserQuery string // Search by user email
|
||||
|
||||
// Optional correlation keys for exact matching.
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
|
||||
// View controls error categorization for list endpoints.
|
||||
// - errors: show actionable errors (exclude business-limited / 429 / 529)
|
||||
// - excluded: only show excluded errors
|
||||
// - all: show everything
|
||||
View string
|
||||
|
||||
Page int
|
||||
PageSize int
|
||||
}
|
||||
|
||||
type OpsErrorLogList struct {
|
||||
Errors []*OpsErrorLog `json:"errors"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
type OpsRetryAttempt struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
RequestedByUserID int64 `json:"requested_by_user_id"`
|
||||
SourceErrorID int64 `json:"source_error_id"`
|
||||
Mode string `json:"mode"`
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
PinnedAccountName string `json:"pinned_account_name"`
|
||||
|
||||
Status string `json:"status"`
|
||||
StartedAt *time.Time `json:"started_at"`
|
||||
FinishedAt *time.Time `json:"finished_at"`
|
||||
DurationMs *int64 `json:"duration_ms"`
|
||||
|
||||
// Persisted execution results (best-effort)
|
||||
Success *bool `json:"success"`
|
||||
HTTPStatusCode *int `json:"http_status_code"`
|
||||
UpstreamRequestID *string `json:"upstream_request_id"`
|
||||
UsedAccountID *int64 `json:"used_account_id"`
|
||||
UsedAccountName string `json:"used_account_name"`
|
||||
ResponsePreview *string `json:"response_preview"`
|
||||
ResponseTruncated *bool `json:"response_truncated"`
|
||||
|
||||
// Optional correlation
|
||||
ResultRequestID *string `json:"result_request_id"`
|
||||
ResultErrorID *int64 `json:"result_error_id"`
|
||||
|
||||
ErrorMessage *string `json:"error_message"`
|
||||
}
|
||||
|
||||
type OpsRetryResult struct {
|
||||
AttemptID int64 `json:"attempt_id"`
|
||||
Mode string `json:"mode"`
|
||||
Status string `json:"status"`
|
||||
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
UsedAccountID *int64 `json:"used_account_id"`
|
||||
|
||||
HTTPStatusCode int `json:"http_status_code"`
|
||||
UpstreamRequestID string `json:"upstream_request_id"`
|
||||
|
||||
ResponsePreview string `json:"response_preview"`
|
||||
ResponseTruncated bool `json:"response_truncated"`
|
||||
|
||||
ErrorMessage string `json:"error_message"`
|
||||
|
||||
StartedAt time.Time `json:"started_at"`
|
||||
FinishedAt time.Time `json:"finished_at"`
|
||||
DurationMs int64 `json:"duration_ms"`
|
||||
}
|
||||
259
backend/internal/service/ops_port.go
Normal file
259
backend/internal/service/ops_port.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpsRepository interface {
|
||||
InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error)
|
||||
ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error)
|
||||
GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error)
|
||||
ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error)
|
||||
|
||||
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
|
||||
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
|
||||
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
|
||||
ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
|
||||
UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
|
||||
|
||||
// Lightweight window stats (for realtime WS / quick sampling).
|
||||
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
|
||||
// Lightweight realtime traffic summary (for the Ops dashboard header card).
|
||||
GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error)
|
||||
|
||||
GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error)
|
||||
GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error)
|
||||
GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error)
|
||||
GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error)
|
||||
GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error)
|
||||
|
||||
InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error
|
||||
GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error)
|
||||
|
||||
UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error
|
||||
ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error)
|
||||
|
||||
// Alerts (rules + events)
|
||||
ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error)
|
||||
CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error)
|
||||
UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error)
|
||||
DeleteAlertRule(ctx context.Context, id int64) error
|
||||
|
||||
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
|
||||
GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
|
||||
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
|
||||
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
|
||||
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
|
||||
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
|
||||
|
||||
// Alert silences
|
||||
CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
|
||||
IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
|
||||
|
||||
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
|
||||
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
|
||||
GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error)
|
||||
GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error)
|
||||
}
|
||||
|
||||
type OpsInsertErrorLogInput struct {
|
||||
RequestID string
|
||||
ClientRequestID string
|
||||
|
||||
UserID *int64
|
||||
APIKeyID *int64
|
||||
AccountID *int64
|
||||
GroupID *int64
|
||||
ClientIP *string
|
||||
|
||||
Platform string
|
||||
Model string
|
||||
RequestPath string
|
||||
Stream bool
|
||||
UserAgent string
|
||||
|
||||
ErrorPhase string
|
||||
ErrorType string
|
||||
Severity string
|
||||
StatusCode int
|
||||
IsBusinessLimited bool
|
||||
IsCountTokens bool // 是否为 count_tokens 请求
|
||||
|
||||
ErrorMessage string
|
||||
ErrorBody string
|
||||
|
||||
ErrorSource string
|
||||
ErrorOwner string
|
||||
|
||||
UpstreamStatusCode *int
|
||||
UpstreamErrorMessage *string
|
||||
UpstreamErrorDetail *string
|
||||
// UpstreamErrors captures all upstream error attempts observed during handling this request.
|
||||
// It is populated during request processing (gin context) and sanitized+serialized by OpsService.
|
||||
UpstreamErrors []*OpsUpstreamErrorEvent
|
||||
// UpstreamErrorsJSON is the sanitized JSON string stored into ops_error_logs.upstream_errors.
|
||||
// It is set by OpsService.RecordError before persisting.
|
||||
UpstreamErrorsJSON *string
|
||||
|
||||
TimeToFirstTokenMs *int64
|
||||
|
||||
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
||||
RequestBodyTruncated bool
|
||||
RequestBodyBytes *int
|
||||
RequestHeadersJSON *string // optional json string
|
||||
|
||||
IsRetryable bool
|
||||
RetryCount int
|
||||
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type OpsInsertRetryAttemptInput struct {
|
||||
RequestedByUserID int64
|
||||
SourceErrorID int64
|
||||
Mode string
|
||||
PinnedAccountID *int64
|
||||
|
||||
// running|queued etc.
|
||||
Status string
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
type OpsUpdateRetryAttemptInput struct {
|
||||
ID int64
|
||||
|
||||
// succeeded|failed
|
||||
Status string
|
||||
FinishedAt time.Time
|
||||
DurationMs int64
|
||||
|
||||
// Persisted execution results (best-effort)
|
||||
Success *bool
|
||||
HTTPStatusCode *int
|
||||
UpstreamRequestID *string
|
||||
UsedAccountID *int64
|
||||
ResponsePreview *string
|
||||
ResponseTruncated *bool
|
||||
|
||||
// Optional correlation (legacy fields kept)
|
||||
ResultRequestID *string
|
||||
ResultErrorID *int64
|
||||
|
||||
ErrorMessage *string
|
||||
}
|
||||
|
||||
type OpsInsertSystemMetricsInput struct {
|
||||
CreatedAt time.Time
|
||||
WindowMinutes int
|
||||
|
||||
Platform *string
|
||||
GroupID *int64
|
||||
|
||||
SuccessCount int64
|
||||
ErrorCountTotal int64
|
||||
BusinessLimitedCount int64
|
||||
ErrorCountSLA int64
|
||||
|
||||
UpstreamErrorCountExcl429529 int64
|
||||
Upstream429Count int64
|
||||
Upstream529Count int64
|
||||
|
||||
TokenConsumed int64
|
||||
|
||||
QPS *float64
|
||||
TPS *float64
|
||||
|
||||
DurationP50Ms *int
|
||||
DurationP90Ms *int
|
||||
DurationP95Ms *int
|
||||
DurationP99Ms *int
|
||||
DurationAvgMs *float64
|
||||
DurationMaxMs *int
|
||||
|
||||
TTFTP50Ms *int
|
||||
TTFTP90Ms *int
|
||||
TTFTP95Ms *int
|
||||
TTFTP99Ms *int
|
||||
TTFTAvgMs *float64
|
||||
TTFTMaxMs *int
|
||||
|
||||
CPUUsagePercent *float64
|
||||
MemoryUsedMB *int64
|
||||
MemoryTotalMB *int64
|
||||
MemoryUsagePercent *float64
|
||||
|
||||
DBOK *bool
|
||||
RedisOK *bool
|
||||
|
||||
RedisConnTotal *int
|
||||
RedisConnIdle *int
|
||||
|
||||
DBConnActive *int
|
||||
DBConnIdle *int
|
||||
DBConnWaiting *int
|
||||
|
||||
GoroutineCount *int
|
||||
ConcurrencyQueueDepth *int
|
||||
}
|
||||
|
||||
type OpsSystemMetricsSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
WindowMinutes int `json:"window_minutes"`
|
||||
|
||||
CPUUsagePercent *float64 `json:"cpu_usage_percent"`
|
||||
MemoryUsedMB *int64 `json:"memory_used_mb"`
|
||||
MemoryTotalMB *int64 `json:"memory_total_mb"`
|
||||
MemoryUsagePercent *float64 `json:"memory_usage_percent"`
|
||||
|
||||
DBOK *bool `json:"db_ok"`
|
||||
RedisOK *bool `json:"redis_ok"`
|
||||
|
||||
// Config-derived limits (best-effort). These are not historical metrics; they help UI render "current vs max".
|
||||
DBMaxOpenConns *int `json:"db_max_open_conns"`
|
||||
RedisPoolSize *int `json:"redis_pool_size"`
|
||||
|
||||
RedisConnTotal *int `json:"redis_conn_total"`
|
||||
RedisConnIdle *int `json:"redis_conn_idle"`
|
||||
|
||||
DBConnActive *int `json:"db_conn_active"`
|
||||
DBConnIdle *int `json:"db_conn_idle"`
|
||||
DBConnWaiting *int `json:"db_conn_waiting"`
|
||||
|
||||
GoroutineCount *int `json:"goroutine_count"`
|
||||
ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"`
|
||||
}
|
||||
|
||||
type OpsUpsertJobHeartbeatInput struct {
|
||||
JobName string
|
||||
|
||||
LastRunAt *time.Time
|
||||
LastSuccessAt *time.Time
|
||||
LastErrorAt *time.Time
|
||||
LastError *string
|
||||
LastDurationMs *int64
|
||||
}
|
||||
|
||||
type OpsJobHeartbeat struct {
|
||||
JobName string `json:"job_name"`
|
||||
|
||||
LastRunAt *time.Time `json:"last_run_at"`
|
||||
LastSuccessAt *time.Time `json:"last_success_at"`
|
||||
LastErrorAt *time.Time `json:"last_error_at"`
|
||||
LastError *string `json:"last_error"`
|
||||
LastDurationMs *int64 `json:"last_duration_ms"`
|
||||
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type OpsWindowStats struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
ErrorCountTotal int64 `json:"error_count_total"`
|
||||
TokenConsumed int64 `json:"token_consumed"`
|
||||
}
|
||||
40
backend/internal/service/ops_query_mode.go
Normal file
40
backend/internal/service/ops_query_mode.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type OpsQueryMode string
|
||||
|
||||
const (
|
||||
OpsQueryModeAuto OpsQueryMode = "auto"
|
||||
OpsQueryModeRaw OpsQueryMode = "raw"
|
||||
OpsQueryModePreagg OpsQueryMode = "preagg"
|
||||
)
|
||||
|
||||
// ErrOpsPreaggregatedNotPopulated indicates that raw logs exist for a window, but the
|
||||
// pre-aggregation tables are not populated yet. This is primarily used to implement
|
||||
// the forced `preagg` mode UX.
|
||||
var ErrOpsPreaggregatedNotPopulated = errors.New("ops pre-aggregated tables not populated")
|
||||
|
||||
func ParseOpsQueryMode(raw string) OpsQueryMode {
|
||||
v := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch v {
|
||||
case string(OpsQueryModeRaw):
|
||||
return OpsQueryModeRaw
|
||||
case string(OpsQueryModePreagg):
|
||||
return OpsQueryModePreagg
|
||||
default:
|
||||
return OpsQueryModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
func (m OpsQueryMode) IsValid() bool {
|
||||
switch m {
|
||||
case OpsQueryModeAuto, OpsQueryModeRaw, OpsQueryModePreagg:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
36
backend/internal/service/ops_realtime.go
Normal file
36
backend/internal/service/ops_realtime.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IsRealtimeMonitoringEnabled returns true when realtime ops features are enabled.
|
||||
//
|
||||
// This is a soft switch controlled by the DB setting `ops_realtime_monitoring_enabled`,
|
||||
// and it is also gated by the hard switch/soft switch of overall ops monitoring.
|
||||
func (s *OpsService) IsRealtimeMonitoringEnabled(ctx context.Context) bool {
|
||||
if !s.IsMonitoringEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
if s.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRealtimeMonitoringEnabled)
|
||||
if err != nil {
|
||||
// Default enabled when key is missing; fail-open on transient errors.
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "false", "0", "off", "disabled":
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user