First commit

This commit is contained in:
shaw
2025-12-18 13:50:39 +08:00
parent 569f4882e5
commit 642842c29e
218 changed files with 86902 additions and 0 deletions

View File

@@ -0,0 +1,284 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrAccountNotFound = errors.New("account not found")
)
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Credentials *map[string]interface{} `json:"credentials"`
Extra *map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
}
// AccountService 账号管理服务
type AccountService struct {
accountRepo *repository.AccountRepository
groupRepo *repository.GroupRepository
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
return &AccountService{
accountRepo: accountRepo,
groupRepo: groupRepo,
}
}
// Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
for _, groupID := range req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
}
// 创建账号
account := &model.Account{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: model.StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, fmt.Errorf("create account: %w", err)
}
// 绑定分组
if len(req.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
return account, nil
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)
}
return accounts, pagination, nil
}
// ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err)
}
return accounts, nil
}
// ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err)
}
return accounts, nil
}
// Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAccountNotFound
}
return nil, fmt.Errorf("get account: %w", err)
}
// 更新字段
if req.Name != nil {
account.Name = *req.Name
}
if req.Credentials != nil {
account.Credentials = *req.Credentials
}
if req.Extra != nil {
account.Extra = *req.Extra
}
if req.ProxyID != nil {
account.ProxyID = req.ProxyID
}
if req.Concurrency != nil {
account.Concurrency = *req.Concurrency
}
if req.Priority != nil {
account.Priority = *req.Priority
}
if req.Status != nil {
account.Status = *req.Status
}
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, fmt.Errorf("update account: %w", err)
}
// 更新分组绑定
if req.GroupIDs != nil {
// 验证分组是否存在
for _, groupID := range *req.GroupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("group %d not found", groupID)
}
return nil, fmt.Errorf("get group: %w", err)
}
}
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// Delete 删除账号
func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 检查账号是否存在
_, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete account: %w", err)
}
return nil
}
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
account.Status = status
account.ErrorMessage = errorMessage
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("update account: %w", err)
}
return nil
}
// UpdateLastUsed 更新最后使用时间
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil {
return fmt.Errorf("update last used: %w", err)
}
return nil
}
// GetCredential 获取账号凭证(安全访问)
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrAccountNotFound
}
return "", fmt.Errorf("get account: %w", err)
}
return account.GetCredential(key), nil
}
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrAccountNotFound
}
return fmt.Errorf("get account: %w", err)
}
// 根据平台执行不同的测试逻辑
switch account.Platform {
case model.PlatformAnthropic:
// TODO: 测试Anthropic API凭证
return nil
case model.PlatformOpenAI:
// TODO: 测试OpenAI API凭证
return nil
case model.PlatformGemini:
// TODO: 测试Gemini API凭证
return nil
default:
return fmt.Errorf("unsupported platform: %s", account.Platform)
}
}

View File

@@ -0,0 +1,314 @@
package service
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"sub2api/internal/repository"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
testModel = "claude-sonnet-4-5-20250929"
)
// TestEvent represents a SSE event for account testing
type TestEvent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Success bool `json:"success,omitempty"`
Error string `json:"error,omitempty"`
}
// AccountTestService handles account testing operations
type AccountTestService struct {
repos *repository.Repositories
oauthService *OAuthService
httpClient *http.Client
}
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
return &AccountTestService{
repos: repos,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 60 * time.Second,
},
}
}
// generateSessionString generates a Claude Code style session string
func generateSessionString() string {
bytes := make([]byte, 32)
rand.Read(bytes)
hex64 := hex.EncodeToString(bytes)
sessionUUID := uuid.New().String()
return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID)
}
// createTestPayload creates a minimal test request payload for OAuth/Setup Token accounts
func createTestPayload() map[string]interface{} {
return map[string]interface{}{
"model": testModel,
"messages": []map[string]interface{}{
{
"role": "user",
"content": []map[string]interface{}{
{
"type": "text",
"text": "hi",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
},
},
},
"system": []map[string]interface{}{
{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
},
"metadata": map[string]string{
"user_id": generateSessionString(),
},
"max_tokens": 1024,
"temperature": 1,
"stream": true,
}
}
// createApiKeyTestPayload creates a simpler test request payload for API Key accounts
func createApiKeyTestPayload(model string) map[string]interface{} {
return map[string]interface{}{
"model": model,
"messages": []map[string]interface{}{
{
"role": "user",
"content": "hi",
},
},
"max_tokens": 1024,
"stream": true,
}
}
// TestAccountConnection tests an account's connection by sending a test request
func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int64) error {
ctx := c.Request.Context()
// Get account
account, err := s.repos.Account.GetByID(ctx, accountID)
if err != nil {
return s.sendErrorAndEnd(c, "Account not found")
}
// Determine authentication method based on account type
var authToken string
var authType string // "bearer" for OAuth, "apikey" for API Key
var apiURL string
if account.IsOAuth() {
// OAuth or Setup Token account
authType = "bearer"
apiURL = testClaudeAPIURL
authToken = account.GetCredential("access_token")
if authToken == "" {
return s.sendErrorAndEnd(c, "No access token available")
}
// Check if token needs refresh
needRefresh := false
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
if err == nil && time.Now().Unix()+300 > expiresAt { // 5 minute buffer
needRefresh = true
}
}
if needRefresh && s.oauthService != nil {
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
}
authToken = tokenInfo.AccessToken
}
} else if account.Type == "apikey" {
// API Key account
authType = "apikey"
authToken = account.GetCredential("api_key")
if authToken == "" {
return s.sendErrorAndEnd(c, "No API key available")
}
// Get base URL (use default if not set)
apiURL = account.GetBaseURL()
if apiURL == "" {
apiURL = "https://api.anthropic.com"
}
// Append /v1/messages endpoint
apiURL = strings.TrimSuffix(apiURL, "/") + "/v1/messages"
} else {
return s.sendErrorAndEnd(c, fmt.Sprintf("Unsupported account type: %s", account.Type))
}
// Set SSE headers
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
// Create test request payload
var payload map[string]interface{}
var actualModel string
if authType == "apikey" {
// Use simpler payload for API Key (without Claude Code specific fields)
// Apply model mapping if configured
actualModel = account.GetMappedModel(testModel)
payload = createApiKeyTestPayload(actualModel)
} else {
actualModel = testModel
payload = createTestPayload()
}
payloadBytes, _ := json.Marshal(payload)
// Send test_start event with model info
s.sendEvent(c, TestEvent{Type: "test_start", Model: actualModel})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create request")
}
// Set headers based on auth type
req.Header.Set("Content-Type", "application/json")
req.Header.Set("anthropic-version", "2023-06-01")
if authType == "bearer" {
req.Header.Set("Authorization", "Bearer "+authToken)
req.Header.Set("anthropic-beta", "prompt-caching-2024-07-31,interleaved-thinking-2025-05-14,output-128k-2025-02-19")
} else {
// API Key uses x-api-key header
req.Header.Set("x-api-key", authToken)
}
// Configure proxy if account has one
transport := http.DefaultTransport.(*http.Transport).Clone()
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
}
client := &http.Client{
Transport: transport,
Timeout: 60 * time.Second,
}
resp, err := client.Do(req)
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return s.sendErrorAndEnd(c, fmt.Sprintf("API returned %d: %s", resp.StatusCode, string(body)))
}
// Process SSE stream
return s.processStream(c, resp.Body)
}
// processStream processes the SSE stream from Claude API
func (s *AccountTestService) processStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
// Stream ended, send complete event
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read error: %s", err.Error()))
}
line = strings.TrimSpace(line)
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
var data map[string]interface{}
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
continue
}
eventType, _ := data["type"].(string)
switch eventType {
case "content_block_delta":
if delta, ok := data["delta"].(map[string]interface{}); ok {
if text, ok := delta["text"].(string); ok {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
}
case "message_stop":
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
case "error":
errorMsg := "Unknown error"
if errData, ok := data["error"].(map[string]interface{}); ok {
if msg, ok := errData["message"].(string); ok {
errorMsg = msg
}
}
return s.sendErrorAndEnd(c, errorMsg)
}
}
}
// sendEvent sends a SSE event to the client
func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) {
eventJSON, _ := json.Marshal(event)
fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON)
c.Writer.Flush()
}
// sendErrorAndEnd sends an error event and ends the stream
func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error {
log.Printf("Account test error: %s", errorMsg)
s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg})
return fmt.Errorf(errorMsg)
}

View File

@@ -0,0 +1,345 @@
package service
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"sync"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
)
// usageCache 用于缓存usage数据
type usageCache struct {
data *UsageInfo
timestamp time.Time
}
var (
usageCacheMap = sync.Map{}
cacheTTL = 10 * time.Minute
)
// WindowStats 窗口期统计
type WindowStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
// UsageProgress 使用量进度
type UsageProgress struct {
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
}
// ClaudeUsageResponse Anthropic API返回的usage结构
type ClaudeUsageResponse struct {
FiveHour struct {
Utilization float64 `json:"utilization"`
ResetsAt string `json:"resets_at"`
} `json:"five_hour"`
SevenDay struct {
Utilization float64 `json:"utilization"`
ResetsAt string `json:"resets_at"`
} `json:"seven_day"`
SevenDaySonnet struct {
Utilization float64 `json:"utilization"`
ResetsAt string `json:"resets_at"`
} `json:"seven_day_sonnet"`
}
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
repos *repository.Repositories
oauthService *OAuthService
httpClient *http.Client
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
return &AccountUsageService{
repos: repos,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// GetUsage 获取账号使用量
// OAuth账号: 调用Anthropic API获取真实数据需要profile scope缓存10分钟
// Setup Token账号: 根据session_window推算5h窗口7d数据不可用没有profile scope
// API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.repos.Account.GetByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account failed: %w", err)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
// 检查缓存
if cached, ok := usageCacheMap.Load(accountID); ok {
cache := cached.(*usageCache)
if time.Since(cache.timestamp) < cacheTTL {
return cache.data, nil
}
}
// 从API获取数据
usage, err := s.fetchOAuthUsage(ctx, account)
if err != nil {
return nil, err
}
// 添加5h窗口统计数据
s.addWindowStats(ctx, account, usage)
// 缓存结果
usageCacheMap.Store(accountID, &usageCache{
data: usage,
timestamp: time.Now(),
})
return usage, nil
}
// Setup Token账号根据session_window推算没有profile scope无法调用usage API
if account.Type == model.AccountTypeSetupToken {
usage := s.estimateSetupTokenUsage(account)
// 添加窗口统计
s.addWindowStats(ctx, account, usage)
return usage, nil
}
// API Key账号不支持usage查询
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
// addWindowStats 为usage数据添加窗口期统计
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) {
if usage.FiveHour == nil {
return
}
// 使用session_window_start作为统计起始时间
var startTime time.Time
if account.SessionWindowStart != nil {
startTime = *account.SessionWindowStart
} else {
// 如果没有窗口信息使用5小时前作为默认
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return
}
usage.FiveHour.WindowStats = &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
}
}
// GetTodayStats 获取账号今日统计
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get today stats failed: %w", err)
}
return &WindowStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
}, nil
}
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
// 获取access token从credentials中获取
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return nil, fmt.Errorf("no access token available")
}
// 获取代理配置
transport := http.DefaultTransport.(*http.Transport).Clone()
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
// 构建请求
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
// 解析响应
var usageResp ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
// 转换为UsageInfo
now := time.Now()
return s.buildUsageInfo(&usageResp, &now), nil
}
// parseTime 尝试多种格式解析时间
func parseTime(s string) (time.Time, error) {
formats := []string{
time.RFC3339,
time.RFC3339Nano,
"2006-01-02T15:04:05Z",
"2006-01-02T15:04:05.000Z",
}
for _, format := range formats {
if t, err := time.Parse(format, s); err == nil {
return t, nil
}
}
return time.Time{}, fmt.Errorf("unable to parse time: %s", s)
}
// buildUsageInfo 构建UsageInfo
func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedAt *time.Time) *UsageInfo {
info := &UsageInfo{
UpdatedAt: updatedAt,
}
// 5小时窗口
if resp.FiveHour.ResetsAt != "" {
if fiveHourReset, err := parseTime(resp.FiveHour.ResetsAt); err == nil {
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
ResetsAt: &fiveHourReset,
RemainingSeconds: int(time.Until(fiveHourReset).Seconds()),
}
} else {
log.Printf("Failed to parse FiveHour.ResetsAt: %s, error: %v", resp.FiveHour.ResetsAt, err)
// 即使解析失败也返回utilization
info.FiveHour = &UsageProgress{
Utilization: resp.FiveHour.Utilization,
}
}
}
// 7天窗口
if resp.SevenDay.ResetsAt != "" {
if sevenDayReset, err := parseTime(resp.SevenDay.ResetsAt); err == nil {
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
ResetsAt: &sevenDayReset,
RemainingSeconds: int(time.Until(sevenDayReset).Seconds()),
}
} else {
log.Printf("Failed to parse SevenDay.ResetsAt: %s, error: %v", resp.SevenDay.ResetsAt, err)
info.SevenDay = &UsageProgress{
Utilization: resp.SevenDay.Utilization,
}
}
}
// 7天Sonnet窗口
if resp.SevenDaySonnet.ResetsAt != "" {
if sonnetReset, err := parseTime(resp.SevenDaySonnet.ResetsAt); err == nil {
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
ResetsAt: &sonnetReset,
RemainingSeconds: int(time.Until(sonnetReset).Seconds()),
}
} else {
log.Printf("Failed to parse SevenDaySonnet.ResetsAt: %s, error: %v", resp.SevenDaySonnet.ResetsAt, err)
info.SevenDaySonnet = &UsageProgress{
Utilization: resp.SevenDaySonnet.Utilization,
}
}
}
return info
}
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo {
info := &UsageInfo{}
// 如果有session_window信息
if account.SessionWindowEnd != nil {
remaining := int(time.Until(*account.SessionWindowEnd).Seconds())
if remaining < 0 {
remaining = 0
}
// 根据状态估算使用率 (百分比形式100 = 100%)
var utilization float64
switch account.SessionWindowStatus {
case "rejected":
utilization = 100.0
case "allowed_warning":
utilization = 80.0
default:
utilization = 0.0
}
info.FiveHour = &UsageProgress{
Utilization: utilization,
ResetsAt: account.SessionWindowEnd,
RemainingSeconds: remaining,
}
} else {
// 没有窗口信息,返回空数据
info.FiveHour = &UsageProgress{
Utilization: 0,
RemainingSeconds: 0,
}
}
// Setup Token无法获取7d数据
return info
}

View File

@@ -0,0 +1,989 @@
package service
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/net/proxy"
"gorm.io/gorm"
)
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error)
GetUser(ctx context.Context, id int64) (*model.User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error)
// Group management
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
GetAllGroups(ctx context.Context) ([]model.Group, error)
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error)
GetGroup(ctx context.Context, id int64) (*model.Group, error)
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error)
GetAccount(ctx context.Context, id int64) (*model.Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error)
DeleteAccount(ctx context.Context, id int64) error
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
// Proxy management
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
GetAllProxies(ctx context.Context) ([]model.Proxy, error)
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
GetProxy(ctx context.Context, id int64) (*model.Proxy, error)
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
// Redeem code management
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error)
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error)
DeleteRedeemCode(ctx context.Context, id int64) error
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
}
// Input types for admin operations
type CreateUserInput struct {
Email string
Password string
Balance float64
Concurrency int
AllowedGroups []int64
}
type UpdateUserInput struct {
Email string
Password string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
}
type CreateGroupInput struct {
Name string
Description string
Platform string
RateMultiplier float64
IsExclusive bool
SubscriptionType string // standard/subscription
DailyLimitUSD *float64 // 日限额 (USD)
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
}
type UpdateGroupInput struct {
Name string
Description string
Platform string
RateMultiplier *float64 // 使用指针以支持设置为0
IsExclusive *bool
Status string
SubscriptionType string // standard/subscription
DailyLimitUSD *float64 // 日限额 (USD)
WeeklyLimitUSD *float64 // 周限额 (USD)
MonthlyLimitUSD *float64 // 月限额 (USD)
}
type CreateAccountInput struct {
Name string
Platform string
Type string
Credentials map[string]interface{}
Extra map[string]interface{}
ProxyID *int64
Concurrency int
Priority int
GroupIDs []int64
}
type UpdateAccountInput struct {
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]interface{}
Extra map[string]interface{}
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
}
type CreateProxyInput struct {
Name string
Protocol string
Host string
Port int
Username string
Password string
}
type UpdateProxyInput struct {
Name string
Protocol string
Host string
Port int
Username string
Password string
Status string
}
type GenerateRedeemCodesInput struct {
Count int
Type string
Value float64
GroupID *int64 // 订阅类型专用关联的分组ID
ValidityDays int // 订阅类型专用:有效天数
}
// ProxyTestResult represents the result of testing a proxy
type ProxyTestResult struct {
Success bool `json:"success"`
Message string `json:"message"`
LatencyMs int64 `json:"latency_ms,omitempty"`
IPAddress string `json:"ip_address,omitempty"`
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
}
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
accountRepo *repository.AccountRepository
proxyRepo *repository.ProxyRepository
apiKeyRepo *repository.ApiKeyRepository
redeemCodeRepo *repository.RedeemCodeRepository
usageLogRepo *repository.UsageLogRepository
userSubRepo *repository.UserSubscriptionRepository
billingCacheService *BillingCacheService
}
// NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories) AdminService {
return &adminServiceImpl{
userRepo: repos.User,
groupRepo: repos.Group,
accountRepo: repos.Account,
proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
// 注意AdminService是接口需要类型断言
func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) {
if impl, ok := adminService.(*adminServiceImpl); ok {
impl.billingCacheService = billingCacheService
}
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil {
return nil, 0, err
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) {
return s.userRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
user := &model.User{
Email: input.Email,
Role: "user", // Always create as regular user, never admin
Balance: input.Balance,
Concurrency: input.Concurrency,
Status: model.StatusActive,
}
if err := user.SetPassword(input.Password); err != nil {
return nil, err
}
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
return user, nil
}
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// Protect admin users: cannot disable admin accounts
if user.Role == "admin" && input.Status == "disabled" {
return nil, errors.New("cannot disable admin user")
}
// Track balance and concurrency changes for logging
oldBalance := user.Balance
oldConcurrency := user.Concurrency
if input.Email != "" {
user.Email = input.Email
}
if input.Password != "" {
if err := user.SetPassword(input.Password); err != nil {
return nil, err
}
}
// Role is not allowed to be changed via API to prevent privilege escalation
if input.Status != "" {
user.Status = input.Status
}
// 只在指针非 nil 时更新 Balance支持设置为 0
if input.Balance != nil {
user.Balance = *input.Balance
}
// 只在指针非 nil 时更新 Concurrency支持设置为任意值
if input.Concurrency != nil {
user.Concurrency = *input.Concurrency
}
// 只在指针非 nil 时更新 AllowedGroups
if input.AllowedGroups != nil {
user.AllowedGroups = *input.AllowedGroups
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 余额变化时失效缓存
if input.Balance != nil && *input.Balance != oldBalance {
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, id)
}()
}
}
// Create adjustment records for balance/concurrency changes
balanceDiff := user.Balance - oldBalance
if balanceDiff != 0 {
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: model.AdjustmentTypeAdminBalance,
Value: balanceDiff,
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
}
}
concurrencyDiff := user.Concurrency - oldConcurrency
if concurrencyDiff != 0 {
adjustmentRecord := &model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: model.AdjustmentTypeAdminConcurrency,
Value: float64(concurrencyDiff),
Status: model.StatusUsed,
UsedBy: &user.ID,
}
now := time.Now()
adjustmentRecord.UsedAt = &now
if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil {
// Log error but don't fail the update
// The user update has already succeeded
}
}
return user, nil
}
func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
// Protect admin users: cannot delete admin accounts
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return err
}
if user.Role == "admin" {
return errors.New("cannot delete admin user")
}
return s.userRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, err
}
switch operation {
case "set":
user.Balance = balance
case "add":
user.Balance += balance
case "subtract":
user.Balance -= balance
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, 0, err
}
return keys, result.Total, nil
}
func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (interface{}, error) {
// Return mock data for now
return map[string]interface{}{
"period": period,
"total_requests": 0,
"total_cost": 0.0,
"total_tokens": 0,
"avg_duration_ms": 0,
}, nil
}
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
return nil, 0, err
}
return groups, result.Total, nil
}
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) {
return s.groupRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
return s.groupRepo.ListActiveByPlatform(ctx, platform)
}
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) {
return s.groupRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) {
platform := input.Platform
if platform == "" {
platform = model.PlatformAnthropic
}
subscriptionType := input.SubscriptionType
if subscriptionType == "" {
subscriptionType = model.SubscriptionTypeStandard
}
group := &model.Group{
Name: input.Name,
Description: input.Description,
Platform: platform,
RateMultiplier: input.RateMultiplier,
IsExclusive: input.IsExclusive,
Status: model.StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
MonthlyLimitUSD: input.MonthlyLimitUSD,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
}
return group, nil
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
group.Name = input.Name
}
if input.Description != "" {
group.Description = input.Description
}
if input.Platform != "" {
group.Platform = input.Platform
}
if input.RateMultiplier != nil {
group.RateMultiplier = *input.RateMultiplier
}
if input.IsExclusive != nil {
group.IsExclusive = *input.IsExclusive
}
if input.Status != "" {
group.Status = input.Status
}
// 订阅相关字段
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段支持设置为nil清除限额或具体值
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
return group, nil
}
func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
// 先获取分组信息,检查是否存在
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("group not found: %w", err)
}
// 订阅类型分组先获取受影响的用户ID列表用于事务后失效缓存
var affectedUserIDs []int64
if group.IsSubscriptionType() && s.billingCacheService != nil {
var subscriptions []model.UserSubscription
if err := s.groupRepo.DB().WithContext(ctx).
Where("group_id = ?", id).
Select("user_id").
Find(&subscriptions).Error; err == nil {
for _, sub := range subscriptions {
affectedUserIDs = append(affectedUserIDs, sub.UserID)
}
}
}
// 使用事务处理所有级联删除
db := s.groupRepo.DB()
err = db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 1. 如果是订阅类型分组,删除 user_subscriptions 中的相关记录
if group.IsSubscriptionType() {
if err := tx.Where("group_id = ?", id).Delete(&model.UserSubscription{}).Error; err != nil {
return fmt.Errorf("delete user subscriptions: %w", err)
}
}
// 2. 将 api_keys 中绑定该分组的 group_id 设为 nil任何类型的分组都需要
if err := tx.Model(&model.ApiKey{}).Where("group_id = ?", id).Update("group_id", nil).Error; err != nil {
return fmt.Errorf("clear api key group_id: %w", err)
}
// 3. 从 users.allowed_groups 数组中移除该分组 ID
if err := tx.Model(&model.User{}).
Where("? = ANY(allowed_groups)", id).
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", id)).Error; err != nil {
return fmt.Errorf("remove from allowed_groups: %w", err)
}
// 4. 删除 account_groups 中间表的数据
if err := tx.Where("group_id = ?", id).Delete(&model.AccountGroup{}).Error; err != nil {
return fmt.Errorf("delete account groups: %w", err)
}
// 5. 删除分组本身
if err := tx.Delete(&model.Group{}, id).Error; err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
})
if err != nil {
return err
}
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
groupID := id
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
for _, userID := range affectedUserIDs {
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}
}()
}
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
return nil, 0, err
}
return keys, result.Total, nil
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
return nil, 0, err
}
return accounts, result.Total, nil
}
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) {
return s.accountRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) {
account := &model.Account{
Name: input.Name,
Platform: input.Platform,
Type: input.Type,
Credentials: model.JSONB(input.Credentials),
Extra: model.JSONB(input.Extra),
ProxyID: input.ProxyID,
Concurrency: input.Concurrency,
Priority: input.Priority,
Status: model.StatusActive,
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
// 绑定分组
if len(input.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, input.GroupIDs); err != nil {
return nil, err
}
}
return account, nil
}
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
account.Name = input.Name
}
if input.Type != "" {
account.Type = input.Type
}
if input.Credentials != nil && len(input.Credentials) > 0 {
account.Credentials = model.JSONB(input.Credentials)
}
if input.Extra != nil && len(input.Extra) > 0 {
account.Extra = model.JSONB(input.Extra)
}
if input.ProxyID != nil {
account.ProxyID = input.ProxyID
}
// 只在指针非 nil 时更新 Concurrency支持设置为 0
if input.Concurrency != nil {
account.Concurrency = *input.Concurrency
}
// 只在指针非 nil 时更新 Priority支持设置为 0
if input.Priority != nil {
account.Priority = *input.Priority
}
if input.Status != "" {
account.Status = input.Status
}
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
// 更新分组绑定
if input.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *input.GroupIDs); err != nil {
return nil, err
}
}
return account, nil
}
func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
return s.accountRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// TODO: Implement refresh logic
return account, nil
}
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
account.Status = model.StatusActive
account.ErrorMessage = ""
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, err
}
return account, nil
}
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) {
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
return nil, err
}
return s.accountRepo.GetByID(ctx, id)
}
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
}
return proxies, result.Total, nil
}
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) {
return s.proxyRepo.ListActive(ctx)
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
return s.proxyRepo.ListActiveWithAccountCount(ctx)
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) {
return s.proxyRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) {
proxy := &model.Proxy{
Name: input.Name,
Protocol: input.Protocol,
Host: input.Host,
Port: input.Port,
Username: input.Username,
Password: input.Password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
return proxy, nil
}
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Name != "" {
proxy.Name = input.Name
}
if input.Protocol != "" {
proxy.Protocol = input.Protocol
}
if input.Host != "" {
proxy.Host = input.Host
}
if input.Port != 0 {
proxy.Port = input.Port
}
if input.Username != "" {
proxy.Username = input.Username
}
if input.Password != "" {
proxy.Password = input.Password
}
if input.Status != "" {
proxy.Status = input.Status
}
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
return nil, err
}
return proxy, nil
}
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
return s.proxyRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) {
// Return mock data for now - would need a dedicated repository method
return []model.Account{}, 0, nil
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password)
}
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err
}
return codes, result.Total, nil
}
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
return s.redeemCodeRepo.GetByID(ctx, id)
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) {
// 如果是订阅类型,验证必须有 GroupID
if input.Type == model.RedeemTypeSubscription {
if input.GroupID == nil {
return nil, errors.New("group_id is required for subscription type")
}
// 验证分组存在且为订阅类型
group, err := s.groupRepo.GetByID(ctx, *input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, errors.New("group must be subscription type")
}
}
codes := make([]model.RedeemCode, 0, input.Count)
for i := 0; i < input.Count; i++ {
code := model.RedeemCode{
Code: model.GenerateRedeemCode(),
Type: input.Type,
Value: input.Value,
Status: model.StatusUnused,
}
// 订阅类型专用字段
if input.Type == model.RedeemTypeSubscription {
code.GroupID = input.GroupID
code.ValidityDays = input.ValidityDays
if code.ValidityDays <= 0 {
code.ValidityDays = 30 // 默认30天
}
}
if err := s.redeemCodeRepo.Create(ctx, &code); err != nil {
return nil, err
}
codes = append(codes, code)
}
return codes, nil
}
func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error {
return s.redeemCodeRepo.Delete(ctx, id)
}
func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
var deleted int64
for _, id := range ids {
if err := s.redeemCodeRepo.Delete(ctx, id); err == nil {
deleted++
}
}
return deleted, nil
}
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemCodeRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
code.Status = model.StatusExpired
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
return nil, err
}
return code, nil
}
func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
return testProxyConnection(ctx, proxy)
}
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
proxyURL := proxy.URL()
// Create HTTP client with proxy
transport, err := createProxyTransport(proxyURL)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
}, nil
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
// Measure latency
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create request: %v", err),
}, nil
}
resp, err := client.Do(req)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Proxy connection failed: %v", err),
}, nil
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
LatencyMs: latencyMs,
}, nil
}
// Parse ipinfo.io response
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to read response",
LatencyMs: latencyMs,
}, nil
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to parse response",
LatencyMs: latencyMs,
}, nil
}
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible",
LatencyMs: latencyMs,
IPAddress: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
}, nil
}
// createProxyTransport creates an HTTP transport with the given proxy URL
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -0,0 +1,464 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
)
const (
apiKeyRateLimitKeyPrefix = "apikey:create_rate_limit:"
apiKeyMaxErrorsPerHour = 20
apiKeyRateLimitDuration = time.Hour
)
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo *repository.ApiKeyRepository
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
userSubRepo *repository.UserSubscriptionRepository
rdb *redis.Client
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo *repository.ApiKeyRepository,
userRepo *repository.UserRepository,
groupRepo *repository.GroupRepository,
userSubRepo *repository.UserSubscriptionRepository,
rdb *redis.Client,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cfg: cfg,
}
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
if prefix == "" {
prefix = "sk-"
}
key := prefix + hex.EncodeToString(bytes)
return key, nil
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') || c == '_' || c == '-') {
return ErrApiKeyInvalidChars
}
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, _ = pipe.Exec(ctx)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
return err == nil // 有有效订阅则允许
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
// 检查用户是否可以绑定该分组
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
}
var key string
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
// 验证自定义Key格式
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
return nil, err
}
// 检查Key是否已存在
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
}
key = *req.CustomKey
} else {
// 生成随机API Key
var err error
key, err = s.GenerateKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
}
// 创建API Key记录
apiKey := &model.ApiKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: model.StatusActive,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
return nil, fmt.Errorf("create api key: %w", err)
}
return apiKey, nil
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}
return keys, pagination, nil
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
return apiKey, nil
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
// 这里可以添加Redis缓存逻辑暂时直接查询数据库
apiKey, err := s.apiKeyRepo.GetByKey(ctx, key)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
// 缓存到Redis可选TTL设置为5分钟
if s.rdb != nil {
// 这里可以序列化并缓存API Key
_ = cacheKey // 使用变量避免未使用错误
}
return apiKey, nil
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrApiKeyNotFound
}
return nil, fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return nil, ErrInsufficientPerms
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
}
if req.GroupID != nil {
// 验证分组权限
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("group not found")
}
return nil, fmt.Errorf("get group: %w", err)
}
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
apiKey.GroupID = req.GroupID
}
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
}
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
return apiKey, nil
}
// Delete 删除API Key
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrApiKeyNotFound
}
return fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return ErrInsufficientPerms
}
// 清除Redis缓存
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:%s", apiKey.Key)
_ = s.rdb.Del(ctx, cacheKey)
}
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)
}
return nil
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
return nil, nil, err
}
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, errors.New("api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, ErrUserNotFound
}
return nil, nil, fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
return apiKey, user, nil
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.rdb != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.rdb.Incr(ctx, cacheKey).Err(); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.rdb.Expire(ctx, cacheKey, 24*time.Hour)
}
return nil
}
// GetAvailableGroups 获取用户有权限绑定的分组列表
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 获取所有活跃分组
allGroups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
// 构建订阅分组 ID 集合
subscribedGroupIDs := make(map[int64]bool)
for _, sub := range activeSubscriptions {
subscribedGroupIDs[sub.GroupID] = true
}
// 过滤出用户有权限的分组
availableGroups := make([]model.Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
}
}
return availableGroups, nil
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}

View File

@@ -0,0 +1,376 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotActive = errors.New("user is not active")
ErrEmailExists = errors.New("email already exists")
ErrInvalidToken = errors.New("invalid token")
ErrTokenExpired = errors.New("token has expired")
ErrEmailVerifyRequired = errors.New("email verification is required")
ErrRegDisabled = errors.New("registration is currently disabled")
)
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
jwt.RegisteredClaims
}
// AuthService 认证服务
type AuthService struct {
userRepo *repository.UserRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
turnstileService *TurnstileService
emailQueueService *EmailQueueService
}
// NewAuthService 创建认证服务实例
func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService {
return &AuthService{
userRepo: userRepo,
cfg: cfg,
}
}
// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证)
func (s *AuthService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// SetEmailService 设置邮件服务(用于邮件验证)
func (s *AuthService) SetEmailService(emailService *EmailService) {
s.emailService = emailService
}
// SetTurnstileService 设置Turnstile服务用于验证码校验
func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) {
s.turnstileService = turnstileService
}
// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件)
func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) {
s.emailQueueService = emailQueueService
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
return s.RegisterWithVerification(ctx, email, password, "")
}
// RegisterWithVerification 用户注册支持邮件验证返回token和用户
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return "", nil, ErrRegDisabled
}
// 检查是否需要邮件验证
if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) {
if verifyCode == "" {
return "", nil, ErrEmailVerifyRequired
}
// 验证邮箱验证码
if s.emailService != nil {
if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil {
return "", nil, fmt.Errorf("verify code: %w", err)
}
}
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return "", nil, fmt.Errorf("check email exists: %w", err)
}
if existsEmail {
return "", nil, ErrEmailExists
}
// 密码哈希
hashedPassword, err := s.HashPassword(password)
if err != nil {
return "", nil, fmt.Errorf("hash password: %w", err)
}
// 获取默认配置
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
// 创建用户
user := &model.User{
Email: email,
PasswordHash: hashedPassword,
Role: model.RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: model.StatusActive,
}
if err := s.userRepo.Create(ctx, user); err != nil {
return "", nil, fmt.Errorf("create user: %w", err)
}
// 生成token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// SendVerifyCodeResult 发送验证码返回结果
type SendVerifyCodeResult struct {
Countdown int `json:"countdown"` // 倒计时秒数
}
// SendVerifyCode 发送邮箱验证码(同步方式)
func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error {
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
return ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
return fmt.Errorf("check email exists: %w", err)
}
if existsEmail {
return ErrEmailExists
}
// 发送验证码
if s.emailService == nil {
return errors.New("email service not configured")
}
// 获取网站名称
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
return s.emailService.SendVerifyCode(ctx, email, siteName)
}
// SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时
func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) {
log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email)
// 检查是否开放注册
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
log.Println("[Auth] Registration is disabled")
return nil, ErrRegDisabled
}
// 检查邮箱是否已存在
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Error checking email exists: %v", err)
return nil, fmt.Errorf("check email exists: %w", err)
}
if existsEmail {
log.Printf("[Auth] Email already exists: %s", email)
return nil, ErrEmailExists
}
// 检查邮件队列服务是否配置
if s.emailQueueService == nil {
log.Println("[Auth] Email queue service not configured")
return nil, errors.New("email queue service not configured")
}
// 获取网站名称
siteName := "Sub2API"
if s.settingService != nil {
siteName = s.settingService.GetSiteName(ctx)
}
// 异步发送
log.Printf("[Auth] Enqueueing verify code for: %s", email)
if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil {
log.Printf("[Auth] Failed to enqueue: %v", err)
return nil, fmt.Errorf("enqueue verify code: %w", err)
}
log.Printf("[Auth] Verify code enqueued successfully for: %s", email)
return &SendVerifyCodeResult{
Countdown: 60, // 60秒倒计时
}, nil
}
// VerifyTurnstile 验证Turnstile token
func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error {
if s.turnstileService == nil {
return nil // 服务未配置则跳过验证
}
return s.turnstileService.VerifyToken(ctx, token, remoteIP)
}
// IsTurnstileEnabled 检查是否启用Turnstile验证
func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool {
if s.turnstileService == nil {
return false
}
return s.turnstileService.IsEnabled(ctx)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool {
if s.settingService == nil {
return true
}
return s.settingService.IsRegistrationEnabled(ctx)
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
if s.settingService == nil {
return false
}
return s.settingService.IsEmailVerifyEnabled(ctx)
}
// Login 用户登录返回JWT token
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
// 查找用户
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", nil, ErrInvalidCredentials
}
return "", nil, fmt.Errorf("get user by email: %w", err)
}
// 验证密码
if !s.CheckPassword(password, user.PasswordHash) {
return "", nil, ErrInvalidCredentials
}
// 检查用户状态
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
// 生成JWT token
token, err := s.GenerateToken(user)
if err != nil {
return "", nil, fmt.Errorf("generate token: %w", err)
}
return token, user, nil
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if err != nil {
if errors.Is(err, jwt.ErrTokenExpired) {
return nil, ErrTokenExpired
}
return nil, ErrInvalidToken
}
if claims, ok := token.Claims.(*JWTClaims); ok && token.Valid {
return claims, nil
}
return nil, ErrInvalidToken
}
// GenerateToken 生成JWT token
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
claims := &JWTClaims{
UserID: user.ID,
Email: user.Email,
Role: user.Role,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(s.cfg.JWT.Secret))
if err != nil {
return "", fmt.Errorf("sign token: %w", err)
}
return tokenString, nil
}
// HashPassword 使用bcrypt加密密码
func (s *AuthService) HashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// CheckPassword 验证密码是否匹配
func (s *AuthService) CheckPassword(password, hashedPassword string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
return err == nil
}
// RefreshToken 刷新token
func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (string, error) {
// 验证旧token即使过期也允许用于刷新
claims, err := s.ValidateToken(oldTokenString)
if err != nil && !errors.Is(err, ErrTokenExpired) {
return "", err
}
// 获取最新的用户信息
user, err := s.userRepo.GetByID(ctx, claims.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrInvalidToken
}
return "", fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return "", ErrUserNotActive
}
// 生成新token
return s.GenerateToken(user)
}

View File

@@ -0,0 +1,422 @@
package service
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// 缓存Key前缀和TTL
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// 订阅缓存Hash字段
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
// 错误定义
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = errors.New("subscription is invalid or expired")
)
// 预编译的Lua脚本
var (
// deductBalanceScript: 扣减余额缓存key不存在则忽略
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// updateSubUsageScript: 更新订阅用量缓存key不存在则忽略
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
type subscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
rdb *redis.Client
userRepo *repository.UserRepository
subRepo *repository.UserSubscriptionRepository
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{
rdb: rdb,
userRepo: userRepo,
subRepo: subRepo,
}
}
// ============================================
// 余额缓存方法
// ============================================
// GetUserBalance 获取用户余额(优先从缓存读取)
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
if s.rdb == nil {
// Redis不可用直接查询数据库
return s.getUserBalanceFromDB(ctx, userID)
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 尝试从缓存读取
val, err := s.rdb.Get(ctx, key).Result()
if err == nil {
balance, parseErr := strconv.ParseFloat(val, 64)
if parseErr == nil {
return balance, nil
}
}
// 缓存未命中或解析错误,从数据库读取
balance, err := s.getUserBalanceFromDB(ctx, userID)
if err != nil {
return 0, err
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setBalanceCache(cacheCtx, userID, balance)
}()
return balance, nil
}
// getUserBalanceFromDB 从数据库获取用户余额
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return 0, fmt.Errorf("get user balance: %w", err)
}
return user.Balance, nil
}
// setBalanceCache 设置余额缓存
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Set(ctx, key, balance, billingCacheTTL).Err(); err != nil {
log.Printf("Warning: set balance cache failed for user %d: %v", userID, err)
}
}
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
// 使用预编译的Lua脚本原子性扣减如果key不存在则忽略
_, err := deductBalanceScript.Run(ctx, s.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
// ============================================
// 订阅缓存方法
// ============================================
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
if s.rdb == nil {
return s.getSubscriptionFromDB(ctx, userID, groupID)
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 尝试从缓存读取
result, err := s.rdb.HGetAll(ctx, key).Result()
if err == nil && len(result) > 0 {
data, parseErr := s.parseSubscriptionCache(result)
if parseErr == nil {
return data, nil
}
}
// 缓存未命中,从数据库读取
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
if err != nil {
return nil, err
}
// 异步建立缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
}()
return data, nil
}
// getSubscriptionFromDB 从数据库获取订阅数据
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, fmt.Errorf("get subscription: %w", err)
}
return &subscriptionCacheData{
Status: sub.Status,
ExpiresAt: sub.ExpiresAt,
DailyUsage: sub.DailyUsageUSD,
WeeklyUsage: sub.WeeklyUsageUSD,
MonthlyUsage: sub.MonthlyUsageUSD,
Version: sub.UpdatedAt.Unix(),
}, nil
}
// parseSubscriptionCache 解析订阅缓存数据
func (s *BillingCacheService) parseSubscriptionCache(data map[string]string) (*subscriptionCacheData, error) {
result := &subscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
// setSubscriptionCache 设置订阅缓存
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
if s.rdb == nil || data == nil {
return
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]interface{}{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := s.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
if _, err := pipe.Exec(ctx); err != nil {
log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
// 使用预编译的Lua脚本原子性增加用量如果key不存在则忽略
_, err := updateSubUsageScript.Run(ctx, s.rdb, []string{key}, costUSD, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
if err := s.rdb.Del(ctx, key).Err(); err != nil {
log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
// ============================================
// 统一检查方法
// ============================================
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error {
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
if isSubscriptionMode {
return s.checkSubscriptionEligibility(ctx, user.ID, group, subscription)
}
return s.checkBalanceEligibility(ctx, user.ID)
}
// checkBalanceEligibility 检查余额模式资格
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
// 缓存/数据库错误,允许通过(降级处理)
log.Printf("Warning: get user balance failed, allowing request: %v", err)
return nil
}
if balance <= 0 {
return ErrInsufficientBalance
}
return nil
}
// checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error {
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
// 缓存/数据库错误降级使用传入的subscription进行检查
log.Printf("Warning: get subscription cache failed, using fallback: %v", err)
return s.checkSubscriptionLimitsFallback(subscription, group)
}
// 检查订阅状态
if subData.Status != model.SubscriptionStatusActive {
return ErrSubscriptionInvalid
}
// 检查是否过期
if time.Now().After(subData.ExpiresAt) {
return ErrSubscriptionInvalid
}
// 检查限额使用传入的Group限额配置
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
return ErrDailyLimitExceeded
}
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
return ErrWeeklyLimitExceeded
}
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
return ErrMonthlyLimitExceeded
}
return nil
}
// checkSubscriptionLimitsFallback 降级检查订阅限额
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error {
if subscription == nil {
return ErrSubscriptionInvalid
}
if !subscription.IsActive() {
return ErrSubscriptionInvalid
}
if !subscription.CheckDailyLimit(group, 0) {
return ErrDailyLimitExceeded
}
if !subscription.CheckWeeklyLimit(group, 0) {
return ErrWeeklyLimitExceeded
}
if !subscription.CheckMonthlyLimit(group, 0) {
return ErrMonthlyLimitExceeded
}
return nil
}

View File

@@ -0,0 +1,279 @@
package service
import (
"fmt"
"log"
"strings"
"sub2api/internal/config"
)
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建价格每百万token- 仅用于硬编码回退
CacheCreation1hPrice float64 // 1小时缓存创建价格每百万token- 仅用于硬编码回退
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
}
// UsageTokens 使用的token数量
type UsageTokens struct {
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
}
// CostBreakdown 费用明细
type CostBreakdown struct {
InputCost float64
OutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
}
// BillingService 计费服务
type BillingService struct {
cfg *config.Config
pricingService *PricingService
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
}
// NewBillingService 创建计费服务实例
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
s := &BillingService{
cfg: cfg,
pricingService: pricingService,
fallbackPrices: make(map[string]*ModelPricing),
}
// 初始化硬编码回退价格(当动态价格不可用时使用)
s.initFallbackPricing()
return s
}
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
// 价格单位USD per token与LiteLLM格式一致
func (s *BillingService) initFallbackPricing() {
// Claude 4.5 Opus
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
InputPricePerToken: 5e-6, // $5 per MTok
OutputPricePerToken: 25e-6, // $25 per MTok
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
SupportsCacheBreakdown: false,
}
// Claude 4 Sonnet
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
InputPricePerToken: 3e-6, // $3 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3.5 Sonnet
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
InputPricePerToken: 3e-6, // $3 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3.5 Haiku
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
InputPricePerToken: 1e-6, // $1 per MTok
OutputPricePerToken: 5e-6, // $5 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3 Opus
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
InputPricePerToken: 15e-6, // $15 per MTok
OutputPricePerToken: 75e-6, // $75 per MTok
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3 Haiku
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
InputPricePerToken: 0.25e-6, // $0.25 per MTok
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
SupportsCacheBreakdown: false,
}
}
// getFallbackPricing 根据模型系列获取回退价格
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
modelLower := strings.ToLower(model)
// 按模型系列匹配
if strings.Contains(modelLower, "opus") {
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
return s.fallbackPrices["claude-opus-4.5"]
}
return s.fallbackPrices["claude-3-opus"]
}
if strings.Contains(modelLower, "sonnet") {
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
return s.fallbackPrices["claude-sonnet-4"]
}
return s.fallbackPrices["claude-3-5-sonnet"]
}
if strings.Contains(modelLower, "haiku") {
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
return s.fallbackPrices["claude-3-5-haiku"]
}
return s.fallbackPrices["claude-3-haiku"]
}
// 默认使用Sonnet价格
return s.fallbackPrices["claude-sonnet-4"]
}
// GetModelPricing 获取模型价格配置
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
// 标准化模型名称(转小写)
model = strings.ToLower(model)
// 1. 优先从动态价格服务获取
if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil {
return &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
SupportsCacheBreakdown: false,
}, nil
}
}
// 2. 使用硬编码回退价格
fallback := s.getFallbackPricing(model)
if fallback != nil {
log.Printf("[Billing] Using fallback pricing for model: %s", model)
return fallback, nil
}
return nil, fmt.Errorf("pricing not found for model: %s", model)
}
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
pricing, err := s.GetModelPricing(model)
if err != nil {
return nil, err
}
breakdown := &CostBreakdown{}
// 计算输入token费用使用per-token价格
breakdown.InputCost = float64(tokens.InputTokens) * pricing.InputPricePerToken
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * pricing.OutputPricePerToken
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型5分钟/1小时缓存
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)/1_000_000*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)/1_000_000*pricing.CacheCreation1hPrice
} else {
// 标准缓存创建价格per-token
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * pricing.CacheReadPricePerToken
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
return breakdown, nil
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
multiplier := s.cfg.Default.RateMultiplier
if multiplier <= 0 {
multiplier = 1.0
}
return s.CalculateCost(model, tokens, multiplier)
}
// ListSupportedModels 列出所有支持的模型现在总是返回true因为有模糊匹配
func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0)
// 返回回退价格支持的模型系列
for model := range s.fallbackPrices {
models = append(models, model)
}
return models
}
// IsModelSupported 检查模型是否支持现在总是返回true因为有模糊匹配回退
func (s *BillingService) IsModelSupported(model string) bool {
// 所有Claude模型都有回退价格支持
modelLower := strings.ToLower(model)
return strings.Contains(modelLower, "claude") ||
strings.Contains(modelLower, "opus") ||
strings.Contains(modelLower, "sonnet") ||
strings.Contains(modelLower, "haiku")
}
// GetEstimatedCost 估算费用(用于前端展示)
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
tokens := UsageTokens{
InputTokens: estimatedInputTokens,
OutputTokens: estimatedOutputTokens,
}
breakdown, err := s.CalculateCostWithConfig(model, tokens)
if err != nil {
return 0, err
}
return breakdown.ActualCost, nil
}
// GetPricingServiceStatus 获取价格服务状态
func (s *BillingService) GetPricingServiceStatus() map[string]interface{} {
if s.pricingService != nil {
return s.pricingService.GetStatus()
}
return map[string]interface{}{
"model_count": len(s.fallbackPrices),
"last_updated": "using fallback",
"local_hash": "N/A",
}
}
// ForceUpdatePricing 强制更新价格数据
func (s *BillingService) ForceUpdatePricing() error {
if s.pricingService != nil {
return s.pricingService.ForceUpdate()
}
return fmt.Errorf("pricing service not initialized")
}

View File

@@ -0,0 +1,251 @@
package service
import (
"context"
"fmt"
"log"
"time"
"github.com/redis/go-redis/v9"
)
const (
// Redis key prefixes
accountConcurrencyKey = "concurrency:account:"
userConcurrencyKey = "concurrency:user:"
userWaitCountKey = "concurrency:wait:"
// TTL for concurrency keys (auto-release safety net)
concurrencyKeyTTL = 10 * time.Minute
// Wait polling interval
waitPollInterval = 100 * time.Millisecond
// Default max wait time
defaultMaxWait = 60 * time.Second
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20
)
// Pre-compiled Lua scripts for better performance
var (
// acquireScript: increment counter if below max, return 1 if successful
acquireScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current < tonumber(ARGV[1]) then
redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
`)
// releaseScript: decrement counter, but don't go below 0
releaseScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// incrementWaitScript: increment wait counter if below max, return 1 if successful
incrementWaitScript = redis.NewScript(`
local waitKey = KEYS[1]
local maxWait = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('GET', waitKey)
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= maxWait then
return 0
end
redis.call('INCR', waitKey)
redis.call('EXPIRE', waitKey, ttl)
return 1
`)
// decrementWaitScript: decrement wait counter, but don't go below 0
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
)
// ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct {
rdb *redis.Client
}
// NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(rdb *redis.Client) *ConcurrencyService {
return &ConcurrencyService{rdb: rdb}
}
// AcquireResult represents the result of acquiring a concurrency slot
type AcquireResult struct {
Acquired bool
ReleaseFunc func() // Must be called when done (typically via defer)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.acquireSlot(ctx, key, maxConcurrency)
}
// acquireSlot is the core implementation for acquiring a concurrency slot
func (s *ConcurrencyService) acquireSlot(ctx context.Context, key string, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
// Try to acquire immediately
acquired, err := s.tryAcquire(ctx, key, maxConcurrency)
if err != nil {
return nil, err
}
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: s.makeReleaseFunc(key),
}, nil
}
// Not acquired, return with Acquired=false
// The caller (gateway handler) will handle waiting with ping support
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// tryAcquire attempts to increment the counter if below max
// Uses pre-compiled Lua script for atomicity and performance
func (s *ConcurrencyService) tryAcquire(ctx context.Context, key string, maxConcurrency int) (bool, error) {
result, err := acquireScript.Run(ctx, s.rdb, []string{key}, maxConcurrency, int(concurrencyKeyTTL.Seconds())).Int()
if err != nil {
return false, fmt.Errorf("acquire slot failed: %w", err)
}
return result == 1, nil
}
// makeReleaseFunc creates a function to release a concurrency slot
func (s *ConcurrencyService) makeReleaseFunc(key string) func() {
return func() {
// Use background context to ensure release even if original context is cancelled
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := releaseScript.Run(ctx, s.rdb, []string{key}).Err(); err != nil {
// Log error but don't panic - TTL will eventually clean up
log.Printf("Warning: failed to release concurrency slot for %s: %v", key, err)
}
}
}
// GetCurrentCount returns the current concurrency count for debugging/monitoring
func (s *ConcurrencyService) GetCurrentCount(ctx context.Context, key string) (int, error) {
val, err := s.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, err
}
return val, nil
}
// GetAccountCurrentCount returns current concurrency count for an account
func (s *ConcurrencyService) GetAccountCurrentCount(ctx context.Context, accountID int64) (int, error) {
key := fmt.Sprintf("%s%d", accountConcurrencyKey, accountID)
return s.GetCurrentCount(ctx, key)
}
// GetUserCurrentCount returns current concurrency count for a user
func (s *ConcurrencyService) GetUserCurrentCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userConcurrencyKey, userID)
return s.GetCurrentCount(ctx, key)
}
// ============================================
// Wait Queue Count Methods
// ============================================
// IncrementWaitCount attempts to increment the wait queue counter for a user.
// Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
if s.rdb == nil {
// Redis not available, allow request
return true, nil
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
result, err := incrementWaitScript.Run(ctx, s.rdb, []string{key}, maxWait, int(concurrencyKeyTTL.Seconds())).Int()
if err != nil {
// On error, allow the request to proceed (fail open)
log.Printf("Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil
}
return result == 1, nil
}
// DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue.
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
// Use background context to ensure decrement even if original context is cancelled
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := decrementWaitScript.Run(bgCtx, s.rdb, []string{key}).Err(); err != nil {
log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err)
}
}
// GetUserWaitCount returns current wait queue count for a user
func (s *ConcurrencyService) GetUserWaitCount(ctx context.Context, userID int64) (int, error) {
key := fmt.Sprintf("%s%d", userWaitCountKey, userID)
return s.GetCurrentCount(ctx, key)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
if userConcurrency <= 0 {
userConcurrency = 1
}
return userConcurrency + defaultExtraWaitSlots
}

View File

@@ -0,0 +1,109 @@
package service
import (
"context"
"fmt"
"log"
"sync"
"time"
)
// EmailTask 邮件发送任务
type EmailTask struct {
Email string
SiteName string
TaskType string // "verify_code"
}
// EmailQueueService 异步邮件队列服务
type EmailQueueService struct {
emailService *EmailService
taskChan chan EmailTask
wg sync.WaitGroup
stopChan chan struct{}
workers int
}
// NewEmailQueueService 创建邮件队列服务
func NewEmailQueueService(emailService *EmailService, workers int) *EmailQueueService {
if workers <= 0 {
workers = 3 // 默认3个工作协程
}
service := &EmailQueueService{
emailService: emailService,
taskChan: make(chan EmailTask, 100), // 缓冲100个任务
stopChan: make(chan struct{}),
workers: workers,
}
// 启动工作协程
service.start()
return service
}
// start 启动工作协程
func (s *EmailQueueService) start() {
for i := 0; i < s.workers; i++ {
s.wg.Add(1)
go s.worker(i)
}
log.Printf("[EmailQueue] Started %d workers", s.workers)
}
// worker 工作协程
func (s *EmailQueueService) worker(id int) {
defer s.wg.Done()
for {
select {
case task := <-s.taskChan:
s.processTask(id, task)
case <-s.stopChan:
log.Printf("[EmailQueue] Worker %d stopping", id)
return
}
}
}
// processTask 处理任务
func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
switch task.TaskType {
case "verify_code":
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
} else {
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
}
default:
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
}
}
// EnqueueVerifyCode 将验证码发送任务加入队列
func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
task := EmailTask{
Email: email,
SiteName: siteName,
TaskType: "verify_code",
}
select {
case s.taskChan <- task:
log.Printf("[EmailQueue] Enqueued verify code task for %s", email)
return nil
default:
return fmt.Errorf("email queue is full")
}
}
// Stop 停止队列服务
func (s *EmailQueueService) Stop() {
close(s.stopChan)
s.wg.Wait()
log.Println("[EmailQueue] All workers stopped")
}

View File

@@ -0,0 +1,372 @@
package service
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"math/big"
"net/smtp"
"strconv"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
)
var (
ErrEmailNotConfigured = errors.New("email service not configured")
ErrInvalidVerifyCode = errors.New("invalid or expired verification code")
ErrVerifyCodeTooFrequent = errors.New("please wait before requesting a new code")
ErrVerifyCodeMaxAttempts = errors.New("too many failed attempts, please request a new code")
)
const (
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
)
// verifyCodeData Redis 中存储的验证码数据
type verifyCodeData struct {
Code string `json:"code"`
Attempts int `json:"attempts"`
CreatedAt time.Time `json:"created_at"`
}
// SmtpConfig SMTP配置
type SmtpConfig struct {
Host string
Port int
Username string
Password string
From string
FromName string
UseTLS bool
}
// EmailService 邮件服务
type EmailService struct {
settingRepo *repository.SettingRepository
rdb *redis.Client
}
// NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
return &EmailService{
settingRepo: settingRepo,
rdb: rdb,
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
keys := []string{
model.SettingKeySmtpHost,
model.SettingKeySmtpPort,
model.SettingKeySmtpUsername,
model.SettingKeySmtpPassword,
model.SettingKeySmtpFrom,
model.SettingKeySmtpFromName,
model.SettingKeySmtpUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[model.SettingKeySmtpHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[model.SettingKeySmtpUseTLS] == "true"
return &SmtpConfig{
Host: host,
Port: port,
Username: settings[model.SettingKeySmtpUsername],
Password: settings[model.SettingKeySmtpPassword],
From: settings[model.SettingKeySmtpFrom],
FromName: settings[model.SettingKeySmtpFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
if err != nil {
return err
}
return s.SendEmailWithConfig(config, to, subject, body)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
}
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
from, to, subject, body)
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if config.UseTLS {
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
}
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg))
}
// sendMailTLS 使用TLS发送邮件
func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
tlsConfig := &tls.Config{
ServerName: host,
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls dial: %w", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, host)
if err != nil {
return fmt.Errorf("new smtp client: %w", err)
}
defer client.Close()
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err = client.Mail(from); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := client.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
_, err = w.Write(msg)
if err != nil {
return fmt.Errorf("write msg: %w", err)
}
err = w.Close()
if err != nil {
return fmt.Errorf("close writer: %w", err)
}
// Email is sent successfully after w.Close(), ignore Quit errors
// Some SMTP servers return non-standard responses on QUIT
_ = client.Quit()
return nil
}
// GenerateVerifyCode 生成6位数字验证码
func (s *EmailService) GenerateVerifyCode() (string, error) {
const digits = "0123456789"
code := make([]byte, 6)
for i := range code {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
if err != nil {
return "", err
}
code[i] = digits[num.Int64()]
}
return string(code), nil
}
// SendVerifyCode 发送验证码邮件
func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName string) error {
key := verifyCodeKeyPrefix + email
// 检查是否在冷却期内
existing, err := s.getVerifyCodeData(ctx, key)
if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent
}
}
// 生成验证码
code, err := s.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
// 保存验证码到 Redis
data := &verifyCodeData{
Code: code,
Attempts: 0,
CreatedAt: time.Now(),
}
if err := s.setVerifyCodeData(ctx, key, data); err != nil {
return fmt.Errorf("save verify code: %w", err)
}
// 构建邮件内容
subject := fmt.Sprintf("[%s] Email Verification Code", siteName)
body := s.buildVerifyCodeEmailBody(code, siteName)
// 发送邮件
if err := s.SendEmail(ctx, email, subject, body); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
// VerifyCode 验证验证码
func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error {
key := verifyCodeKeyPrefix + email
data, err := s.getVerifyCodeData(ctx, key)
if err != nil || data == nil {
return ErrInvalidVerifyCode
}
// 检查是否已达到最大尝试次数
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
// 验证码不匹配
if data.Code != code {
data.Attempts++
_ = s.setVerifyCodeData(ctx, key, data)
if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts
}
return ErrInvalidVerifyCode
}
// 验证成功,删除验证码
s.rdb.Del(ctx, key)
return nil
}
// getVerifyCodeData 从 Redis 获取验证码数据
func (s *EmailService) getVerifyCodeData(ctx context.Context, key string) (*verifyCodeData, error) {
val, err := s.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data verifyCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
// setVerifyCodeData 保存验证码数据到 Redis
func (s *EmailService) setVerifyCodeData(ctx context.Context, key string, data *verifyCodeData) error {
val, err := json.Marshal(data)
if err != nil {
return err
}
return s.rdb.Set(ctx, key, val, verifyCodeTTL).Err()
}
// buildVerifyCodeEmailBody 构建验证码邮件HTML内容
func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
return fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.code { font-size: 36px; font-weight: bold; letter-spacing: 8px; color: #333; background-color: #f8f9fa; padding: 20px 30px; border-radius: 8px; display: inline-block; margin: 20px 0; font-family: monospace; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>%s</h1>
</div>
<div class="content">
<p style="font-size: 18px; color: #333;">Your verification code is:</p>
<div class="code">%s</div>
<div class="info">
<p>This code will expire in <strong>15 minutes</strong>.</p>
<p>If you did not request this code, please ignore this email.</p>
</div>
</div>
<div class="footer">
<p>This is an automated message, please do not reply.</p>
</div>
</div>
</body>
</html>
`, siteName, code)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {
tlsConfig := &tls.Config{ServerName: config.Host}
conn, err := tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("tls connection failed: %w", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, config.Host)
if err != nil {
return fmt.Errorf("smtp client creation failed: %w", err)
}
defer client.Close()
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp authentication failed: %w", err)
}
return client.Quit()
}
// 非TLS连接测试
client, err := smtp.Dial(addr)
if err != nil {
return fmt.Errorf("smtp connection failed: %w", err)
}
defer client.Close()
auth := smtp.PlainAuth("", config.Username, config.Password, config.Host)
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp authentication failed: %w", err)
}
return client.Quit()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,194 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrGroupNotFound = errors.New("group not found")
ErrGroupExists = errors.New("group name already exists")
)
// CreateGroupRequest 创建分组请求
type CreateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
}
// UpdateGroupRequest 更新分组请求
type UpdateGroupRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status *string `json:"status"`
}
// GroupService 分组管理服务
type GroupService struct {
groupRepo *repository.GroupRepository
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
return &GroupService{
groupRepo: groupRepo,
}
}
// Create 创建分组
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) {
// 检查名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
if err != nil {
return nil, fmt.Errorf("check group exists: %w", err)
}
if exists {
return nil, ErrGroupExists
}
// 创建分组
group := &model.Group{
Name: req.Name,
Description: req.Description,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: model.StatusActive,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, fmt.Errorf("create group: %w", err)
}
return group, nil
}
// GetByID 根据ID获取分组
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
return group, nil
}
// List 获取分组列表
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err)
}
return groups, pagination, nil
}
// ListActive 获取活跃分组列表
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
groups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
return groups, nil
}
// Update 更新分组
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
// 更新字段
if req.Name != nil && *req.Name != group.Name {
// 检查新名称是否已存在
exists, err := s.groupRepo.ExistsByName(ctx, *req.Name)
if err != nil {
return nil, fmt.Errorf("check group exists: %w", err)
}
if exists {
return nil, ErrGroupExists
}
group.Name = *req.Name
}
if req.Description != nil {
group.Description = *req.Description
}
if req.RateMultiplier != nil {
group.RateMultiplier = *req.RateMultiplier
}
if req.IsExclusive != nil {
group.IsExclusive = *req.IsExclusive
}
if req.Status != nil {
group.Status = *req.Status
}
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, fmt.Errorf("update group: %w", err)
}
return group, nil
}
// Delete 删除分组
func (s *GroupService) Delete(ctx context.Context, id int64) error {
// 检查分组是否存在
_, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrGroupNotFound
}
return fmt.Errorf("get group: %w", err)
}
if err := s.groupRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete group: %w", err)
}
return nil
}
// GetStats 获取分组统计信息
func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]interface{}, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrGroupNotFound
}
return nil, fmt.Errorf("get group: %w", err)
}
// 获取账号数量
accountCount, err := s.groupRepo.GetAccountCount(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account count: %w", err)
}
stats := map[string]interface{}{
"id": group.ID,
"name": group.Name,
"rate_multiplier": group.RateMultiplier,
"is_exclusive": group.IsExclusive,
"status": group.Status,
"account_count": accountCount,
}
return stats, nil
}

View File

@@ -0,0 +1,282 @@
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"regexp"
"strconv"
"time"
"github.com/redis/go-redis/v9"
)
const (
// Redis key prefix
identityFingerprintKey = "identity:fingerprint:"
)
// 预编译正则表达式(避免每次调用重新编译)
var (
// 匹配 user_id 格式: user_{64位hex}_account__session_{uuid}
userIDRegex = regexp.MustCompile(`^user_[a-f0-9]{64}_account__session_([a-f0-9-]{36})$`)
// 匹配 User-Agent 版本号: xxx/x.y.z
userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`)
)
// Fingerprint 存储的指纹数据结构
type Fingerprint struct {
ClientID string `json:"client_id"` // 64位hex客户端ID首次随机生成
UserAgent string `json:"user_agent"` // User-Agent
StainlessLang string `json:"x_stainless_lang"` // x-stainless-lang
StainlessPackageVersion string `json:"x_stainless_package_version"` // x-stainless-package-version
StainlessOS string `json:"x_stainless_os"` // x-stainless-os
StainlessArch string `json:"x_stainless_arch"` // x-stainless-arch
StainlessRuntime string `json:"x_stainless_runtime"` // x-stainless-runtime
StainlessRuntimeVersion string `json:"x_stainless_runtime_version"` // x-stainless-runtime-version
}
// 默认指纹值(当客户端未提供时使用)
var defaultFingerprint = Fingerprint{
UserAgent: "claude-cli/2.0.62 (external, cli)",
StainlessLang: "js",
StainlessPackageVersion: "0.52.0",
StainlessOS: "Linux",
StainlessArch: "x64",
StainlessRuntime: "node",
StainlessRuntimeVersion: "v22.14.0",
}
// IdentityService 管理OAuth账号的请求身份指纹
type IdentityService struct {
rdb *redis.Client
}
// NewIdentityService 创建新的IdentityService
func NewIdentityService(rdb *redis.Client) *IdentityService {
return &IdentityService{rdb: rdb}
}
// GetOrCreateFingerprint 获取或创建账号的指纹
// 如果缓存存在检测user-agent版本新版本则更新
// 如果缓存不存在生成随机ClientID并从请求头创建指纹然后缓存
func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID int64, headers http.Header) (*Fingerprint, error) {
key := identityFingerprintKey + strconv.FormatInt(accountID, 10)
// 尝试从Redis获取缓存的指纹
data, err := s.rdb.Get(ctx, key).Bytes()
if err == nil && len(data) > 0 {
// 缓存存在,解析指纹
var cached Fingerprint
if err := json.Unmarshal(data, &cached); err == nil {
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
// 更新user-agent
cached.UserAgent = clientUA
// 保存更新后的指纹
if newData, err := json.Marshal(cached); err == nil {
s.rdb.Set(ctx, key, newData, 0) // 永不过期
}
log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return &cached, nil
}
}
// 缓存不存在或解析失败,创建新指纹
fp := s.createFingerprintFromHeaders(headers)
// 生成随机ClientID
fp.ClientID = generateClientID()
// 保存到Redis永不过期
if data, err := json.Marshal(fp); err == nil {
if err := s.rdb.Set(ctx, key, data, 0).Err(); err != nil {
log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
}
log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID)
return fp, nil
}
// createFingerprintFromHeaders 从请求头创建指纹
func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fingerprint {
fp := &Fingerprint{}
// 获取User-Agent
if ua := headers.Get("User-Agent"); ua != "" {
fp.UserAgent = ua
} else {
fp.UserAgent = defaultFingerprint.UserAgent
}
// 获取x-stainless-*头,如果没有则使用默认值
fp.StainlessLang = getHeaderOrDefault(headers, "X-Stainless-Lang", defaultFingerprint.StainlessLang)
fp.StainlessPackageVersion = getHeaderOrDefault(headers, "X-Stainless-Package-Version", defaultFingerprint.StainlessPackageVersion)
fp.StainlessOS = getHeaderOrDefault(headers, "X-Stainless-OS", defaultFingerprint.StainlessOS)
fp.StainlessArch = getHeaderOrDefault(headers, "X-Stainless-Arch", defaultFingerprint.StainlessArch)
fp.StainlessRuntime = getHeaderOrDefault(headers, "X-Stainless-Runtime", defaultFingerprint.StainlessRuntime)
fp.StainlessRuntimeVersion = getHeaderOrDefault(headers, "X-Stainless-Runtime-Version", defaultFingerprint.StainlessRuntimeVersion)
return fp
}
// getHeaderOrDefault 获取header值如果不存在则返回默认值
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
if v := headers.Get(key); v != "" {
return v
}
return defaultValue
}
// ApplyFingerprint 将指纹应用到请求头覆盖原有的x-stainless-*头)
func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) {
if fp == nil {
return
}
// 设置User-Agent
if fp.UserAgent != "" {
req.Header.Set("User-Agent", fp.UserAgent)
}
// 设置x-stainless-*头(使用正确的大小写)
if fp.StainlessLang != "" {
req.Header.Set("X-Stainless-Lang", fp.StainlessLang)
}
if fp.StainlessPackageVersion != "" {
req.Header.Set("X-Stainless-Package-Version", fp.StainlessPackageVersion)
}
if fp.StainlessOS != "" {
req.Header.Set("X-Stainless-OS", fp.StainlessOS)
}
if fp.StainlessArch != "" {
req.Header.Set("X-Stainless-Arch", fp.StainlessArch)
}
if fp.StainlessRuntime != "" {
req.Header.Set("X-Stainless-Runtime", fp.StainlessRuntime)
}
if fp.StainlessRuntimeVersion != "" {
req.Header.Set("X-Stainless-Runtime-Version", fp.StainlessRuntimeVersion)
}
}
// RewriteUserID 重写body中的metadata.user_id
// 输入格式user_{clientId}_account__session_{sessionUUID}
// 输出格式user_{cachedClientID}_account_{accountUUID}_session_{newHash}
func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) {
if len(body) == 0 || accountUUID == "" || cachedClientID == "" {
return body, nil
}
// 解析JSON
var reqMap map[string]interface{}
if err := json.Unmarshal(body, &reqMap); err != nil {
return body, nil
}
metadata, ok := reqMap["metadata"].(map[string]interface{})
if !ok {
return body, nil
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return body, nil
}
// 匹配格式: user_{64位hex}_account__session_{uuid}
matches := userIDRegex.FindStringSubmatch(userID)
if matches == nil {
return body, nil
}
sessionTail := matches[1] // 原始session UUID
// 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式
seed := fmt.Sprintf("%d::%s", accountID, sessionTail)
newSessionHash := generateUUIDFromSeed(seed)
// 构建新的user_id
// 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash}
newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash)
metadata["user_id"] = newUserID
reqMap["metadata"] = metadata
return json.Marshal(reqMap)
}
// generateClientID 生成64位十六进制客户端ID32字节随机数
func generateClientID() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
// 极罕见的情况,使用时间戳+固定值作为fallback
log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err)
// 使用SHA256(当前纳秒时间)作为fallback
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
return hex.EncodeToString(h[:])
}
return hex.EncodeToString(b)
}
// generateUUIDFromSeed 从种子生成确定性UUID v4格式字符串
func generateUUIDFromSeed(seed string) string {
hash := sha256.Sum256([]byte(seed))
bytes := hash[:16]
// 设置UUID v4版本和变体位
bytes[6] = (bytes[6] & 0x0f) | 0x40
bytes[8] = (bytes[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}
// parseUserAgentVersion 解析user-agent版本号
// 例如claude-cli/2.0.62 -> (2, 0, 62)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// 匹配 xxx/x.y.z 格式
matches := userAgentVersionRegex.FindStringSubmatch(ua)
if len(matches) != 4 {
return 0, 0, 0, false
}
major, _ = strconv.Atoi(matches[1])
minor, _ = strconv.Atoi(matches[2])
patch, _ = strconv.Atoi(matches[3])
return major, minor, patch, true
}
// isNewerVersion 比较版本号判断newUA是否比cachedUA更新
func isNewerVersion(newUA, cachedUA string) bool {
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
if !newOk || !cachedOk {
return false
}
// 比较版本号
if newMajor > cachedMajor {
return true
}
if newMajor < cachedMajor {
return false
}
if newMinor > cachedMinor {
return true
}
if newMinor < cachedMinor {
return false
}
return newPatch > cachedPatch
}

View File

@@ -0,0 +1,471 @@
package service
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/oauth"
"sub2api/internal/repository"
"github.com/imroc/req/v3"
)
// OAuthService handles OAuth authentication flows
type OAuthService struct {
sessionStore *oauth.SessionStore
proxyRepo *repository.ProxyRepository
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
return &OAuthService{
sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo,
}
}
// GenerateAuthURLResult contains the authorization URL and session info
type GenerateAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
}
// GenerateAuthURL generates an OAuth authorization URL with full scope
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
return s.generateAuthURLWithScope(ctx, scope, proxyID)
}
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
func (s *OAuthService) GenerateSetupTokenURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
scope := oauth.ScopeInference
return s.generateAuthURLWithScope(ctx, scope, proxyID)
}
func (s *OAuthService) generateAuthURLWithScope(ctx context.Context, scope string, proxyID *int64) (*GenerateAuthURLResult, error) {
// Generate PKCE values
state, err := oauth.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
codeVerifier, err := oauth.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
// Generate session ID
sessionID, err := oauth.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("failed to generate session ID: %w", err)
}
// Get proxy URL if specified
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// Store session
session := &oauth.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
Scope: scope,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
// Build authorization URL
authURL := oauth.BuildAuthorizationURL(state, codeChallenge, scope)
return &GenerateAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
}, nil
}
// ExchangeCodeInput represents the input for code exchange
type ExchangeCodeInput struct {
SessionID string
Code string
ProxyID *int64
}
// TokenInfo represents the token information stored in credentials
type TokenInfo struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
RefreshToken string `json:"refresh_token,omitempty"`
Scope string `json:"scope,omitempty"`
OrgUUID string `json:"org_uuid,omitempty"`
AccountUUID string `json:"account_uuid,omitempty"`
}
// ExchangeCode exchanges authorization code for tokens
func (s *OAuthService) ExchangeCode(ctx context.Context, input *ExchangeCodeInput) (*TokenInfo, error) {
// Get session
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session not found or expired")
}
// Get proxy URL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// Exchange code for token
tokenInfo, err := s.exchangeCodeForToken(ctx, input.Code, session.CodeVerifier, session.State, proxyURL)
if err != nil {
return nil, err
}
// Delete session after successful exchange
s.sessionStore.Delete(input.SessionID)
return tokenInfo, nil
}
// CookieAuthInput represents the input for cookie-based authentication
type CookieAuthInput struct {
SessionKey string
ProxyID *int64
Scope string // "full" or "inference"
}
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (*TokenInfo, error) {
// Get proxy URL if specified
var proxyURL string
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// Determine scope
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
if input.Scope == "inference" {
scope = oauth.ScopeInference
}
// Step 1: Get organization info using sessionKey
orgUUID, err := s.getOrganizationUUID(ctx, input.SessionKey, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to get organization info: %w", err)
}
// Step 2: Generate PKCE values
codeVerifier, err := oauth.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
}
codeChallenge := oauth.GenerateCodeChallenge(codeVerifier)
state, err := oauth.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
// Step 3: Get authorization code using cookie
authCode, err := s.getAuthorizationCode(ctx, input.SessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to get authorization code: %w", err)
}
// Step 4: Exchange code for token
tokenInfo, err := s.exchangeCodeForToken(ctx, authCode, codeVerifier, state, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
// Ensure org_uuid is set (from step 1 if not from token response)
if tokenInfo.OrgUUID == "" && orgUUID != "" {
tokenInfo.OrgUUID = orgUUID
log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID)
}
return tokenInfo, nil
}
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
// getAuthorizationCode gets the authorization code using sessionKey
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
// Build request body - must include organization_uuid as per CRS
reqBody := map[string]interface{}{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID, // Required field!
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
// Response contains redirect_uri with code, not direct code field
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
// Parse redirect_uri to extract code and state
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
// Combine code with state if present (as CRS does)
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
}
// exchangeCodeForToken exchanges authorization code for tokens
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL)
// Parse code#state format if present
authCode := code
codeState := ""
if parts := strings.Split(code, "#"); len(parts) > 1 {
authCode = parts[0]
codeState = parts[1]
}
// Build JSON body as CRS does (not form data!)
reqBody := map[string]interface{}{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
// Add state if present
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
tokenInfo := &TokenInfo{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
RefreshToken: tokenResp.RefreshToken,
Scope: tokenResp.Scope,
}
// Extract org_uuid and account_uuid from response
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
tokenInfo.OrgUUID = tokenResp.Organization.UUID
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
}
if tokenResp.Account != nil && tokenResp.Account.UUID != "" {
tokenInfo.AccountUUID = tokenResp.Account.UUID
log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID)
}
return tokenInfo, nil
}
// RefreshToken refreshes an OAuth token
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &TokenInfo{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: time.Now().Unix() + tokenResp.ExpiresIn,
RefreshToken: tokenResp.RefreshToken,
Scope: tokenResp.Scope,
}, nil
}
// RefreshAccountToken refreshes token for an account
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) {
refreshToken := account.GetCredential("refresh_token")
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.RefreshToken(ctx, refreshToken, proxyURL)
}
// createReqClient creates a req client with Chrome impersonation and optional proxy
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
SetTimeout(60 * time.Second)
// Set proxy if specified
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,572 @@
package service
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"sub2api/internal/config"
)
// LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// LiteLLMRawEntry 用于解析原始JSON数据
type LiteLLMRawEntry struct {
InputCostPerToken *float64 `json:"input_cost_per_token"`
OutputCostPerToken *float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// PricingService 动态价格服务
type PricingService struct {
cfg *config.Config
mu sync.RWMutex
pricingData map[string]*LiteLLMModelPricing
lastUpdated time.Time
localHash string
// 停止信号
stopCh chan struct{}
wg sync.WaitGroup
}
// NewPricingService 创建价格服务
func NewPricingService(cfg *config.Config) *PricingService {
s := &PricingService{
cfg: cfg,
pricingData: make(map[string]*LiteLLMModelPricing),
stopCh: make(chan struct{}),
}
return s
}
// Initialize 初始化价格服务
func (s *PricingService) Initialize() error {
// 确保数据目录存在
if err := os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil {
log.Printf("[Pricing] Failed to create data directory: %v", err)
}
// 首次加载价格数据
if err := s.checkAndUpdatePricing(); err != nil {
log.Printf("[Pricing] Initial load failed, using fallback: %v", err)
if err := s.useFallbackPricing(); err != nil {
return fmt.Errorf("failed to load pricing data: %w", err)
}
}
// 启动定时更新
s.startUpdateScheduler()
log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData))
return nil
}
// Stop 停止价格服务
func (s *PricingService) Stop() {
close(s.stopCh)
s.wg.Wait()
log.Println("[Pricing] Service stopped")
}
// startUpdateScheduler 启动定时更新调度器
func (s *PricingService) startUpdateScheduler() {
// 定期检查哈希更新
hashInterval := time.Duration(s.cfg.Pricing.HashCheckIntervalMinutes) * time.Minute
if hashInterval < time.Minute {
hashInterval = 10 * time.Minute
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(hashInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.syncWithRemote(); err != nil {
log.Printf("[Pricing] Sync failed: %v", err)
}
case <-s.stopCh:
return
}
}
}()
log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval)
}
// checkAndUpdatePricing 检查并更新价格数据
func (s *PricingService) checkAndUpdatePricing() error {
pricingFile := s.getPricingFilePath()
// 检查本地文件是否存在
if _, err := os.Stat(pricingFile); os.IsNotExist(err) {
log.Println("[Pricing] Local pricing file not found, downloading...")
return s.downloadPricingData()
}
// 检查文件是否过期
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
}
fileAge := time.Since(info.ModTime())
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
if fileAge > maxAge {
log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour))
if err := s.downloadPricingData(); err != nil {
log.Printf("[Pricing] Download failed, using existing file: %v", err)
}
}
// 加载本地文件
return s.loadPricingData(pricingFile)
}
// syncWithRemote 与远程同步(基于哈希校验)
func (s *PricingService) syncWithRemote() error {
pricingFile := s.getPricingFilePath()
// 计算本地文件哈希
localHash, err := s.computeFileHash(pricingFile)
if err != nil {
log.Printf("[Pricing] Failed to compute local hash: %v", err)
return s.downloadPricingData()
}
// 如果配置了哈希URL从远程获取哈希进行比对
if s.cfg.Pricing.HashURL != "" {
remoteHash, err := s.fetchRemoteHash()
if err != nil {
log.Printf("[Pricing] Failed to fetch remote hash: %v", err)
return nil // 哈希获取失败不影响正常使用
}
if remoteHash != localHash {
log.Println("[Pricing] Remote hash differs, downloading new version...")
return s.downloadPricingData()
}
log.Println("[Pricing] Hash check passed, no update needed")
return nil
}
// 没有哈希URL时基于时间检查
info, err := os.Stat(pricingFile)
if err != nil {
return s.downloadPricingData()
}
fileAge := time.Since(info.ModTime())
maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour
if fileAge > maxAge {
log.Printf("[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour))
return s.downloadPricingData()
}
return nil
}
// downloadPricingData 从远程下载价格数据
func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
if err != nil {
return fmt.Errorf("download failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response failed: %w", err)
}
// 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body)
if err != nil {
return fmt.Errorf("parse pricing data: %w", err)
}
// 保存到本地文件
pricingFile := s.getPricingFilePath()
if err := os.WriteFile(pricingFile, body, 0644); err != nil {
log.Printf("[Pricing] Failed to save file: %v", err)
}
// 保存哈希
hash := sha256.Sum256(body)
hashStr := hex.EncodeToString(hash[:])
hashFile := s.getHashFilePath()
if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil {
log.Printf("[Pricing] Failed to save hash: %v", err)
}
// 更新内存数据
s.mu.Lock()
s.pricingData = data
s.lastUpdated = time.Now()
s.localHash = hashStr
s.mu.Unlock()
log.Printf("[Pricing] Downloaded %d models successfully", len(data))
return nil
}
// parsePricingData 解析价格数据(处理各种格式)
func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModelPricing, error) {
// 首先解析为 map[string]json.RawMessage
var rawData map[string]json.RawMessage
if err := json.Unmarshal(body, &rawData); err != nil {
return nil, fmt.Errorf("parse raw JSON: %w", err)
}
result := make(map[string]*LiteLLMModelPricing)
skipped := 0
for modelName, rawEntry := range rawData {
// 跳过 sample_spec 等文档条目
if modelName == "sample_spec" {
continue
}
// 尝试解析每个条目
var entry LiteLLMRawEntry
if err := json.Unmarshal(rawEntry, &entry); err != nil {
skipped++
continue
}
// 只保留有有效价格的条目
if entry.InputCostPerToken == nil && entry.OutputCostPerToken == nil {
continue
}
pricing := &LiteLLMModelPricing{
LiteLLMProvider: entry.LiteLLMProvider,
Mode: entry.Mode,
SupportsPromptCaching: entry.SupportsPromptCaching,
}
if entry.InputCostPerToken != nil {
pricing.InputCostPerToken = *entry.InputCostPerToken
}
if entry.OutputCostPerToken != nil {
pricing.OutputCostPerToken = *entry.OutputCostPerToken
}
if entry.CacheCreationInputTokenCost != nil {
pricing.CacheCreationInputTokenCost = *entry.CacheCreationInputTokenCost
}
if entry.CacheReadInputTokenCost != nil {
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
}
result[modelName] = pricing
}
if skipped > 0 {
log.Printf("[Pricing] Skipped %d invalid entries", skipped)
}
if len(result) == 0 {
return nil, fmt.Errorf("no valid pricing entries found")
}
return result, nil
}
// loadPricingData 从本地文件加载价格数据
func (s *PricingService) loadPricingData(filePath string) error {
data, err := os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("read file failed: %w", err)
}
// 使用灵活的解析方式
pricingData, err := s.parsePricingData(data)
if err != nil {
return fmt.Errorf("parse pricing data: %w", err)
}
// 计算哈希
hash := sha256.Sum256(data)
hashStr := hex.EncodeToString(hash[:])
s.mu.Lock()
s.pricingData = pricingData
s.localHash = hashStr
info, _ := os.Stat(filePath)
if info != nil {
s.lastUpdated = info.ModTime()
} else {
s.lastUpdated = time.Now()
}
s.mu.Unlock()
log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath)
return nil
}
// useFallbackPricing 使用回退价格文件
func (s *PricingService) useFallbackPricing() error {
fallbackFile := s.cfg.Pricing.FallbackFile
if _, err := os.Stat(fallbackFile); os.IsNotExist(err) {
return fmt.Errorf("fallback file not found: %s", fallbackFile)
}
log.Printf("[Pricing] Using fallback file: %s", fallbackFile)
// 复制到数据目录
data, err := os.ReadFile(fallbackFile)
if err != nil {
return fmt.Errorf("read fallback failed: %w", err)
}
pricingFile := s.getPricingFilePath()
if err := os.WriteFile(pricingFile, data, 0644); err != nil {
log.Printf("[Pricing] Failed to copy fallback: %v", err)
}
return s.loadPricingData(fallbackFile)
}
// fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) {
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(s.cfg.Pricing.HashURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}
// computeFileHash 计算文件哈希
func (s *PricingService) computeFileHash(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:]), nil
}
// GetModelPricing 获取模型价格(带模糊匹配)
func (s *PricingService) GetModelPricing(modelName string) *LiteLLMModelPricing {
s.mu.RLock()
defer s.mu.RUnlock()
if modelName == "" {
return nil
}
// 标准化模型名称
modelLower := strings.ToLower(modelName)
// 1. 精确匹配
if pricing, ok := s.pricingData[modelLower]; ok {
return pricing
}
if pricing, ok := s.pricingData[modelName]; ok {
return pricing
}
// 2. 处理常见的模型名称变体
// claude-opus-4-5-20251101 -> claude-opus-4.5-20251101
normalized := strings.ReplaceAll(modelLower, "-4-5-", "-4.5-")
if pricing, ok := s.pricingData[normalized]; ok {
return pricing
}
// 3. 尝试模糊匹配(去掉版本号后缀)
// claude-opus-4-5-20251101 -> claude-opus-4.5
baseName := s.extractBaseName(modelLower)
for key, pricing := range s.pricingData {
keyBase := s.extractBaseName(strings.ToLower(key))
if keyBase == baseName {
return pricing
}
}
// 4. 基于模型系列匹配
return s.matchByModelFamily(modelLower)
}
// extractBaseName 提取基础模型名称(去掉日期版本号)
func (s *PricingService) extractBaseName(model string) string {
// 移除日期后缀 (如 -20251101, -20241022)
parts := strings.Split(model, "-")
result := make([]string, 0, len(parts))
for _, part := range parts {
// 跳过看起来像日期的部分8位数字
if len(part) == 8 && isNumeric(part) {
continue
}
// 跳过版本号(如 v1:0
if strings.Contains(part, ":") {
continue
}
result = append(result, part)
}
return strings.Join(result, "-")
}
// matchByModelFamily 基于模型系列匹配
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// Claude模型系列匹配规则
familyPatterns := map[string][]string{
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
"sonnet-3": {"claude-3-sonnet"},
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
"haiku-3": {"claude-3-haiku"},
}
// 确定模型属于哪个系列
var matchedFamily string
for family, patterns := range familyPatterns {
for _, pattern := range patterns {
if strings.Contains(model, pattern) || strings.Contains(model, strings.ReplaceAll(pattern, "-", "")) {
matchedFamily = family
break
}
}
if matchedFamily != "" {
break
}
}
if matchedFamily == "" {
// 简单的系列匹配
if strings.Contains(model, "opus") {
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
matchedFamily = "opus-4.5"
} else {
matchedFamily = "opus-4"
}
} else if strings.Contains(model, "sonnet") {
if strings.Contains(model, "4.5") || strings.Contains(model, "4-5") {
matchedFamily = "sonnet-4.5"
} else if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
matchedFamily = "sonnet-3.5"
} else {
matchedFamily = "sonnet-4"
}
} else if strings.Contains(model, "haiku") {
if strings.Contains(model, "3-5") || strings.Contains(model, "3.5") {
matchedFamily = "haiku-3.5"
} else {
matchedFamily = "haiku-3"
}
}
}
if matchedFamily == "" {
return nil
}
// 在价格数据中查找该系列的模型
patterns := familyPatterns[matchedFamily]
for _, pattern := range patterns {
for key, pricing := range s.pricingData {
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, pattern) {
log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key)
return pricing
}
}
}
return nil
}
// GetStatus 获取服务状态
func (s *PricingService) GetStatus() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
return map[string]interface{}{
"model_count": len(s.pricingData),
"last_updated": s.lastUpdated,
"local_hash": s.localHash[:min(8, len(s.localHash))],
}
}
// ForceUpdate 强制更新
func (s *PricingService) ForceUpdate() error {
return s.downloadPricingData()
}
// getPricingFilePath 获取价格文件路径
func (s *PricingService) getPricingFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.json")
}
// getHashFilePath 获取哈希文件路径
func (s *PricingService) getHashFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
}
// isNumeric 检查字符串是否为纯数字
func isNumeric(s string) bool {
for _, c := range s {
if c < '0' || c > '9' {
return false
}
}
return true
}

View File

@@ -0,0 +1,192 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrProxyNotFound = errors.New("proxy not found")
)
// CreateProxyRequest 创建代理请求
type CreateProxyRequest struct {
Name string `json:"name"`
Protocol string `json:"protocol"`
Host string `json:"host"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}
// UpdateProxyRequest 更新代理请求
type UpdateProxyRequest struct {
Name *string `json:"name"`
Protocol *string `json:"protocol"`
Host *string `json:"host"`
Port *int `json:"port"`
Username *string `json:"username"`
Password *string `json:"password"`
Status *string `json:"status"`
}
// ProxyService 代理管理服务
type ProxyService struct {
proxyRepo *repository.ProxyRepository
}
// NewProxyService 创建代理服务实例
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
return &ProxyService{
proxyRepo: proxyRepo,
}
}
// Create 创建代理
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
// 创建代理
proxy := &model.Proxy{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: model.StatusActive,
}
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, fmt.Errorf("create proxy: %w", err)
}
return proxy, nil
}
// GetByID 根据ID获取代理
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
return proxy, nil
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)
}
return proxies, pagination, nil
}
// ListActive 获取活跃代理列表
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
proxies, err := s.proxyRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active proxies: %w", err)
}
return proxies, nil
}
// Update 更新代理
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProxyNotFound
}
return nil, fmt.Errorf("get proxy: %w", err)
}
// 更新字段
if req.Name != nil {
proxy.Name = *req.Name
}
if req.Protocol != nil {
proxy.Protocol = *req.Protocol
}
if req.Host != nil {
proxy.Host = *req.Host
}
if req.Port != nil {
proxy.Port = *req.Port
}
if req.Username != nil {
proxy.Username = *req.Username
}
if req.Password != nil {
proxy.Password = *req.Password
}
if req.Status != nil {
proxy.Status = *req.Status
}
if err := s.proxyRepo.Update(ctx, proxy); err != nil {
return nil, fmt.Errorf("update proxy: %w", err)
}
return proxy, nil
}
// Delete 删除代理
func (s *ProxyService) Delete(ctx context.Context, id int64) error {
// 检查代理是否存在
_, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
if err := s.proxyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete proxy: %w", err)
}
return nil
}
// TestConnection 测试代理连接(需要实现具体测试逻辑)
func (s *ProxyService) TestConnection(ctx context.Context, id int64) error {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProxyNotFound
}
return fmt.Errorf("get proxy: %w", err)
}
// TODO: 实现代理连接测试逻辑
// 可以尝试通过代理发送测试请求
_ = proxy
return nil
}
// GetURL 获取代理URL
func (s *ProxyService) GetURL(ctx context.Context, id int64) (string, error) {
proxy, err := s.proxyRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", ErrProxyNotFound
}
return "", fmt.Errorf("get proxy: %w", err)
}
return proxy.URL(), nil
}

View File

@@ -0,0 +1,170 @@
package service
import (
"context"
"log"
"net/http"
"strconv"
"time"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
)
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
repos *repository.Repositories
cfg *config.Config
}
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
return &RateLimitService{
repos: repos,
cfg: cfg,
}
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
// apikey 类型账号:检查自定义错误码配置
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
if !account.ShouldHandleErrorCode(statusCode) {
log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
return false
}
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true
case 429:
s.handle429(ctx, account, headers)
return false
case 529:
s.handle529(ctx, account)
return false
default:
// 其他5xx错误记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
}
return false
}
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
}
log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
}
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
}
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
}
resetAt := time.Unix(ts, 0)
// 标记限流状态
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
return
}
// 根据重置时间反推5h窗口
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
log.Printf("Account %d rate limited until %v", account.ID, resetAt)
}
// handle529 处理529过载错误
// 根据配置设置过载冷却时间
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
if cooldownMinutes <= 0 {
cooldownMinutes = 10 // 默认10分钟
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
return
}
log.Printf("Account %d overloaded until %v", account.ID, until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
status := headers.Get("anthropic-ratelimit-unified-5h-status")
if status == "" {
return
}
// 检查是否需要初始化时间窗口
// 对于 Setup Token 账号,首次成功请求时需要预测时间窗口
var windowStart, windowEnd *time.Time
needInitWindow := account.SessionWindowEnd == nil || time.Now().After(*account.SessionWindowEnd)
if needInitWindow && (status == "allowed" || status == "allowed_warning") {
// 预测时间窗口:从当前时间的整点开始,+5小时为结束
// 例如:现在是 14:30窗口为 14:00 ~ 19:00
now := time.Now()
start := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
}
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
}
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.repos.Account.ClearRateLimit(ctx, accountID)
}

View File

@@ -0,0 +1,392 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var (
ErrRedeemCodeNotFound = errors.New("redeem code not found")
ErrRedeemCodeUsed = errors.New("redeem code already used")
ErrRedeemCodeInvalid = errors.New("invalid redeem code")
ErrInsufficientBalance = errors.New("insufficient balance")
ErrRedeemRateLimited = errors.New("too many failed attempts, please try again later")
ErrRedeemCodeLocked = errors.New("redeem code is being processed, please try again")
)
const (
redeemRateLimitKeyPrefix = "redeem:rate_limit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemMaxErrorsPerHour = 20
redeemRateLimitDuration = time.Hour
redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁
)
// GenerateCodesRequest 生成兑换码请求
type GenerateCodesRequest struct {
Count int `json:"count"`
Value float64 `json:"value"`
Type string `json:"type"`
}
// RedeemCodeResponse 兑换码响应
type RedeemCodeResponse struct {
Code string `json:"code"`
Value float64 `json:"value"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
}
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo *repository.RedeemCodeRepository
userRepo *repository.UserRepository
subscriptionService *SubscriptionService
rdb *redis.Client
billingCacheService *BillingCacheService
}
// NewRedeemService 创建兑换码服务实例
func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService {
return &RedeemService{
redeemRepo: redeemRepo,
userRepo: userRepo,
subscriptionService: subscriptionService,
rdb: rdb,
}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// GenerateRandomCode 生成随机兑换码
func (s *RedeemService) GenerateRandomCode() (string, error) {
// 生成16字节随机数据
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串
code := hex.EncodeToString(bytes)
// 格式化为 XXXX-XXXX-XXXX-XXXX 格式
parts := []string{
strings.ToUpper(code[0:8]),
strings.ToUpper(code[8:16]),
strings.ToUpper(code[16:24]),
strings.ToUpper(code[24:32]),
}
return strings.Join(parts, "-"), nil
}
// GenerateCodes 批量生成兑换码
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
if req.Count <= 0 {
return nil, errors.New("count must be greater than 0")
}
if req.Value <= 0 {
return nil, errors.New("value must be greater than 0")
}
if req.Count > 1000 {
return nil, errors.New("cannot generate more than 1000 codes at once")
}
codeType := req.Type
if codeType == "" {
codeType = model.RedeemTypeBalance
}
codes := make([]model.RedeemCode, 0, req.Count)
for i := 0; i < req.Count; i++ {
code, err := s.GenerateRandomCode()
if err != nil {
return nil, fmt.Errorf("generate code: %w", err)
}
codes = append(codes, model.RedeemCode{
Code: code,
Type: codeType,
Value: req.Value,
Status: model.StatusUnused,
})
}
// 批量插入
if err := s.redeemRepo.CreateBatch(ctx, codes); err != nil {
return nil, fmt.Errorf("create batch codes: %w", err)
}
return codes, nil
}
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.rdb == nil {
return nil
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
count, err := s.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
// Redis 出错时不阻止用户操作
return nil
}
if count >= redeemMaxErrorsPerHour {
return ErrRedeemRateLimited
}
return nil
}
// incrementRedeemErrorCount 增加用户兑换错误计数
func (s *RedeemService) incrementRedeemErrorCount(ctx context.Context, userID int64) {
if s.rdb == nil {
return
}
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
pipe := s.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, _ = pipe.Exec(ctx)
}
// acquireRedeemLock 尝试获取兑换码的分布式锁
// 返回 true 表示获取成功false 表示锁已被占用
func (s *RedeemService) acquireRedeemLock(ctx context.Context, code string) bool {
if s.rdb == nil {
return true // 无 Redis 时降级为不加锁
}
key := redeemLockKeyPrefix + code
ok, err := s.rdb.SetNX(ctx, key, "1", redeemLockDuration).Result()
if err != nil {
// Redis 出错时不阻止操作,依赖数据库层面的状态检查
return true
}
return ok
}
// releaseRedeemLock 释放兑换码的分布式锁
func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
if s.rdb == nil {
return
}
key := redeemLockKeyPrefix + code
s.rdb.Del(ctx, key)
}
// Redeem 使用兑换码
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
// 检查限流
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
return nil, err
}
// 获取分布式锁,防止同一兑换码并发使用
if !s.acquireRedeemLock(ctx, code) {
return nil, ErrRedeemCodeLocked
}
defer s.releaseRedeemLock(ctx, code)
// 查找兑换码
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
// 检查兑换码状态
if !redeemCode.CanUse() {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeUsed
}
// 验证兑换码类型的前置条件
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
return nil, errors.New("invalid subscription redeem code: missing group_id")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
_ = user // 使用变量避免未使用错误
// 【关键】先标记兑换码为已使用,确保并发安全
// 利用数据库乐观锁WHERE status = 'unused')保证原子性
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 兑换码已被其他请求使用
return nil, ErrRedeemCodeUsed
}
return nil, fmt.Errorf("mark code as used: %w", err)
}
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
switch redeemCode.Type {
case model.RedeemTypeBalance:
// 增加用户余额
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
// 失效余额缓存
if s.billingCacheService != nil {
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
}()
}
case model.RedeemTypeConcurrency:
// 增加用户并发数
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
return nil, fmt.Errorf("update user concurrency: %w", err)
}
case model.RedeemTypeSubscription:
validityDays := redeemCode.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: *redeemCode.GroupID,
ValidityDays: validityDays,
AssignedBy: 0, // 系统分配
Notes: fmt.Sprintf("通过兑换码 %s 兑换", redeemCode.Code),
})
if err != nil {
return nil, fmt.Errorf("assign or extend subscription: %w", err)
}
// 失效订阅缓存
if s.billingCacheService != nil {
groupID := *redeemCode.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
default:
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
}
// 重新获取更新后的兑换码
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
if err != nil {
return nil, fmt.Errorf("get updated redeem code: %w", err)
}
return redeemCode, nil
}
// GetByID 根据ID获取兑换码
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return code, nil
}
// GetByCode 根据Code获取兑换码
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrRedeemCodeNotFound
}
return nil, fmt.Errorf("get redeem code: %w", err)
}
return redeemCode, nil
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
}
return codes, pagination, nil
}
// Delete 删除兑换码(管理员功能)
func (s *RedeemService) Delete(ctx context.Context, id int64) error {
// 检查兑换码是否存在
code, err := s.redeemRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrRedeemCodeNotFound
}
return fmt.Errorf("get redeem code: %w", err)
}
// 不允许删除已使用的兑换码
if code.IsUsed() {
return errors.New("cannot delete used redeem code")
}
if err := s.redeemRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete redeem code: %w", err)
}
return nil
}
// GetStats 获取兑换码统计信息
func (s *RedeemService) GetStats(ctx context.Context) (map[string]interface{}, error) {
// TODO: 实现统计逻辑
// 统计未使用、已使用的兑换码数量
// 统计总面值等
stats := map[string]interface{}{
"total_codes": 0,
"unused_codes": 0,
"used_codes": 0,
"total_value": 0.0,
}
return stats, nil
}
// GetUserHistory 获取用户的兑换历史
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
if err != nil {
return nil, fmt.Errorf("get user redeem history: %w", err)
}
return codes, nil
}

View File

@@ -0,0 +1,139 @@
package service
import (
"sub2api/internal/config"
"sub2api/internal/repository"
"github.com/redis/go-redis/v9"
)
// Services 服务集合容器
type Services struct {
Auth *AuthService
User *UserService
ApiKey *ApiKeyService
Group *GroupService
Account *AccountService
Proxy *ProxyService
Redeem *RedeemService
Usage *UsageService
Pricing *PricingService
Billing *BillingService
BillingCache *BillingCacheService
Admin AdminService
Gateway *GatewayService
OAuth *OAuthService
RateLimit *RateLimitService
AccountUsage *AccountUsageService
AccountTest *AccountTestService
Setting *SettingService
Email *EmailService
EmailQueue *EmailQueueService
Turnstile *TurnstileService
Subscription *SubscriptionService
Concurrency *ConcurrencyService
Identity *IdentityService
}
// NewServices 创建所有服务实例
func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services {
// 初始化价格服务
pricingService := NewPricingService(cfg)
if err := pricingService.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error())
}
// 初始化计费服务(依赖价格服务)
billingService := NewBillingService(cfg, pricingService)
// 初始化其他服务
authService := NewAuthService(repos.User, cfg)
userService := NewUserService(repos.User, cfg)
apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg)
groupService := NewGroupService(repos.Group)
accountService := NewAccountService(repos.Account, repos.Group)
proxyService := NewProxyService(repos.Proxy)
usageService := NewUsageService(repos.UsageLog, repos.User)
// 初始化订阅服务 (RedeemService 依赖)
subscriptionService := NewSubscriptionService(repos)
// 初始化兑换服务 (依赖订阅服务)
redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb)
// 初始化Admin服务
adminService := NewAdminService(repos)
// 初始化OAuth服务GatewayService依赖
oauthService := NewOAuthService(repos.Proxy)
// 初始化限流服务
rateLimitService := NewRateLimitService(repos, cfg)
// 初始化计费缓存服务
billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription)
// 初始化账号使用量服务
accountUsageService := NewAccountUsageService(repos, oauthService)
// 初始化账号测试服务
accountTestService := NewAccountTestService(repos, oauthService)
// 初始化身份指纹服务
identityService := NewIdentityService(rdb)
// 初始化Gateway服务
gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService)
// 初始化设置服务
settingService := NewSettingService(repos.Setting, cfg)
emailService := NewEmailService(repos.Setting, rdb)
// 初始化邮件队列服务
emailQueueService := NewEmailQueueService(emailService, 3)
// 初始化Turnstile服务
turnstileService := NewTurnstileService(settingService)
// 设置Auth服务的依赖用于注册开关和邮件验证
authService.SetSettingService(settingService)
authService.SetEmailService(emailService)
authService.SetTurnstileService(turnstileService)
authService.SetEmailQueueService(emailQueueService)
// 初始化并发控制服务
concurrencyService := NewConcurrencyService(rdb)
// 注入计费缓存服务到需要失效缓存的服务
redeemService.SetBillingCacheService(billingCacheService)
subscriptionService.SetBillingCacheService(billingCacheService)
SetAdminServiceBillingCache(adminService, billingCacheService)
return &Services{
Auth: authService,
User: userService,
ApiKey: apiKeyService,
Group: groupService,
Account: accountService,
Proxy: proxyService,
Redeem: redeemService,
Usage: usageService,
Pricing: pricingService,
Billing: billingService,
BillingCache: billingCacheService,
Admin: adminService,
Gateway: gatewayService,
OAuth: oauthService,
RateLimit: rateLimitService,
AccountUsage: accountUsageService,
AccountTest: accountTestService,
Setting: settingService,
Email: emailService,
EmailQueue: emailQueueService,
Turnstile: turnstileService,
Subscription: subscriptionService,
Concurrency: concurrencyService,
Identity: identityService,
}
}

View File

@@ -0,0 +1,264 @@
package service
import (
"context"
"errors"
"fmt"
"strconv"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"gorm.io/gorm"
)
var (
ErrRegistrationDisabled = errors.New("registration is currently disabled")
)
// SettingService 系统设置服务
type SettingService struct {
settingRepo *repository.SettingRepository
cfg *config.Config
}
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
settingRepo: settingRepo,
cfg: cfg,
}
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
if err != nil {
return nil, fmt.Errorf("get all settings: %w", err)
}
return s.parseSettings(settings), nil
}
// GetPublicSettings 获取公开设置(无需登录)
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
keys := []string{
model.SettingKeyRegistrationEnabled,
model.SettingKeyEmailVerifyEnabled,
model.SettingKeyTurnstileEnabled,
model.SettingKeyTurnstileSiteKey,
model.SettingKeySiteName,
model.SettingKeySiteLogo,
model.SettingKeySiteSubtitle,
model.SettingKeyApiBaseUrl,
model.SettingKeyContactInfo,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get public settings: %w", err)
}
return &model.PublicSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
}, nil
}
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
updates := make(map[string]string)
// 注册设置
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[model.SettingKeySmtpHost] = settings.SmtpHost
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
}
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
if settings.TurnstileSecretKey != "" {
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
}
// OEM设置
updates[model.SettingKeySiteName] = settings.SiteName
updates[model.SettingKeySiteLogo] = settings.SiteLogo
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[model.SettingKeyContactInfo] = settings.ContactInfo
// 默认配置
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
return s.settingRepo.SetMultiple(ctx, updates)
}
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
if err != nil {
// 默认开放注册
return true
}
return value == "true"
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
if err != nil {
return false
}
return value == "true"
}
// GetSiteName 获取网站名称
func (s *SettingService) GetSiteName(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
if err != nil || value == "" {
return "Sub2API"
}
return value
}
// GetDefaultConcurrency 获取默认并发量
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
if err != nil {
return s.cfg.Default.UserConcurrency
}
if v, err := strconv.Atoi(value); err == nil && v > 0 {
return v
}
return s.cfg.Default.UserConcurrency
}
// GetDefaultBalance 获取默认余额
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
if err != nil {
return s.cfg.Default.UserBalance
}
if v, err := strconv.ParseFloat(value, 64); err == nil && v >= 0 {
return v
}
return s.cfg.Default.UserBalance
}
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
if err == nil {
// 已有设置,不需要初始化
return nil
}
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("check existing settings: %w", err)
}
// 初始化默认设置
defaults := map[string]string{
model.SettingKeyRegistrationEnabled: "true",
model.SettingKeyEmailVerifyEnabled: "false",
model.SettingKeySiteName: "Sub2API",
model.SettingKeySiteLogo: "",
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
model.SettingKeySmtpPort: "587",
model.SettingKeySmtpUseTLS: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
}
// parseSettings 解析设置到结构体
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
result := &model.SystemSettings{
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[model.SettingKeySmtpHost],
SmtpUsername: settings[model.SettingKeySmtpUsername],
SmtpFrom: settings[model.SettingKeySmtpFrom],
SmtpFromName: settings[model.SettingKeySmtpFromName],
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
SiteLogo: settings[model.SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
ContactInfo: settings[model.SettingKeyContactInfo],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
} else {
result.SmtpPort = 587
}
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
result.DefaultConcurrency = concurrency
} else {
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
}
// 解析浮点数类型
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
result.DefaultBalance = balance
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
return result
}
// getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" {
return value
}
return defaultValue
}
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
if err != nil {
return false
}
return value == "true"
}
// GetTurnstileSecretKey 获取 Turnstile Secret Key
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
if err != nil {
return ""
}
return value
}

View File

@@ -0,0 +1,575 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
)
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscriptionExpired = errors.New("subscription has expired")
ErrSubscriptionSuspended = errors.New("subscription is suspended")
ErrSubscriptionAlreadyExists = errors.New("subscription already exists for this user and group")
ErrGroupNotSubscriptionType = errors.New("group is not a subscription type")
ErrDailyLimitExceeded = errors.New("daily usage limit exceeded")
ErrWeeklyLimitExceeded = errors.New("weekly usage limit exceeded")
ErrMonthlyLimitExceeded = errors.New("monthly usage limit exceeded")
)
// SubscriptionService 订阅服务
type SubscriptionService struct {
repos *repository.Repositories
billingCacheService *BillingCacheService
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService {
return &SubscriptionService{repos: repos}
}
// SetBillingCacheService 设置计费缓存服务(用于缓存失效)
func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) {
s.billingCacheService = billingCacheService
}
// AssignSubscriptionInput 分配订阅输入
type AssignSubscriptionInput struct {
UserID int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, ErrGroupNotSubscriptionType
}
// 检查是否已存在订阅
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
return nil, err
}
if exists {
return nil, ErrSubscriptionAlreadyExists
}
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, nil
}
// AssignOrExtendSubscription 分配或续期订阅(用于兑换码等场景)
// 如果用户已有同分组的订阅:
// - 未过期:从当前过期时间累加天数
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
if !group.IsSubscriptionType() {
return nil, false, ErrGroupNotSubscriptionType
}
// 查询是否已有订阅
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
// 不存在记录是正常情况,其他错误需要返回
existingSub = nil
}
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
// 已有订阅,执行续期
if existingSub != nil {
now := time.Now()
var newExpiresAt time.Time
if existingSub.ExpiresAt.After(now) {
// 未过期:从当前过期时间累加
newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays)
} else {
// 已过期:从当前时间开始计算
newExpiresAt = now.AddDate(0, 0, validityDays)
}
// 更新过期时间
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
// 追加备注
if input.Notes != "" {
newNotes := existingSub.Notes
if newNotes != "" {
newNotes += "\n"
}
newNotes += input.Notes
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程
}
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
// 返回更新后的订阅
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
return sub, true, err // true 表示是续期
}
// 没有订阅,创建新订阅
sub, err := s.createSubscription(ctx, input)
if err != nil {
return nil, false, err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := input.UserID, input.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return sub, false, nil // false 表示是新建
}
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
validityDays := input.ValidityDays
if validityDays <= 0 {
validityDays = 30
}
now := time.Now()
sub := &model.UserSubscription{
UserID: input.UserID,
GroupID: input.GroupID,
StartsAt: now,
ExpiresAt: now.AddDate(0, 0, validityDays),
Status: model.SubscriptionStatusActive,
AssignedAt: now,
Notes: input.Notes,
CreatedAt: now,
UpdatedAt: now,
}
// 只有当 AssignedBy > 0 时才设置0 表示系统分配,如兑换码)
if input.AssignedBy > 0 {
sub.AssignedBy = &input.AssignedBy
}
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
return nil, err
}
// 重新获取完整订阅信息(包含关联)
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
}
// BulkAssignSubscriptionInput 批量分配订阅输入
type BulkAssignSubscriptionInput struct {
UserIDs []int64
GroupID int64
ValidityDays int
AssignedBy int64
Notes string
}
// BulkAssignResult 批量分配结果
type BulkAssignResult struct {
SuccessCount int
FailedCount int
Subscriptions []model.UserSubscription
Errors []string
}
// BulkAssignSubscription 批量分配订阅
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
result := &BulkAssignResult{
Subscriptions: make([]model.UserSubscription, 0),
Errors: make([]string, 0),
}
for _, userID := range input.UserIDs {
sub, err := s.AssignSubscription(ctx, &AssignSubscriptionInput{
UserID: userID,
GroupID: input.GroupID,
ValidityDays: input.ValidityDays,
AssignedBy: input.AssignedBy,
Notes: input.Notes,
})
if err != nil {
result.FailedCount++
result.Errors = append(result.Errors, fmt.Sprintf("user %d: %v", userID, err))
} else {
result.SuccessCount++
result.Subscriptions = append(result.Subscriptions, *sub)
}
}
return result, nil
}
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return err
}
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
return err
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return nil
}
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err
}
// 如果订阅已过期恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
return nil, err
}
}
// 失效订阅缓存
if s.billingCacheService != nil {
userID, groupID := sub.UserID, sub.GroupID
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
}()
}
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
return s.repos.UserSubscription.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
return sub, nil
}
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListByUserID(ctx, userID)
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
if sub.IsWindowActivated() {
return nil
}
now := time.Now()
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
}
// CheckAndResetWindows 检查并重置过期的窗口
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
now := time.Now()
// 日窗口重置24小时
if sub.NeedsDailyReset() {
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.DailyWindowStart = &now
sub.DailyUsageUSD = 0
}
// 周窗口重置7天
if sub.NeedsWeeklyReset() {
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.WeeklyWindowStart = &now
sub.WeeklyUsageUSD = 0
}
// 月窗口重置30天
if sub.NeedsMonthlyReset() {
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.MonthlyWindowStart = &now
sub.MonthlyUsageUSD = 0
}
return nil
}
// CheckUsageLimits 检查使用限额(返回错误如果超限)
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded
}
if !sub.CheckWeeklyLimit(group, additionalCost) {
return ErrWeeklyLimitExceeded
}
if !sub.CheckMonthlyLimit(group, additionalCost) {
return ErrMonthlyLimitExceeded
}
return nil
}
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
}
// SubscriptionProgress 订阅进度
type SubscriptionProgress struct {
ID int64 `json:"id"`
GroupName string `json:"group_name"`
ExpiresAt time.Time `json:"expires_at"`
ExpiresInDays int `json:"expires_in_days"`
Daily *UsageWindowProgress `json:"daily,omitempty"`
Weekly *UsageWindowProgress `json:"weekly,omitempty"`
Monthly *UsageWindowProgress `json:"monthly,omitempty"`
}
// UsageWindowProgress 使用窗口进度
type UsageWindowProgress struct {
LimitUSD float64 `json:"limit_usd"`
UsedUSD float64 `json:"used_usd"`
RemainingUSD float64 `json:"remaining_usd"`
Percentage float64 `json:"percentage"`
WindowStart time.Time `json:"window_start"`
ResetsAt time.Time `json:"resets_at"`
ResetsInSeconds int64 `json:"resets_in_seconds"`
}
// GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
group := sub.Group
if group == nil {
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
if err != nil {
return nil, err
}
}
progress := &SubscriptionProgress{
ID: sub.ID,
GroupName: group.Name,
ExpiresAt: sub.ExpiresAt,
ExpiresInDays: sub.DaysRemaining(),
}
// 日进度
if group.HasDailyLimit() && sub.DailyWindowStart != nil {
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
RemainingUSD: limit - sub.DailyUsageUSD,
Percentage: (sub.DailyUsageUSD / limit) * 100,
WindowStart: *sub.DailyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Daily.RemainingUSD < 0 {
progress.Daily.RemainingUSD = 0
}
if progress.Daily.Percentage > 100 {
progress.Daily.Percentage = 100
}
if progress.Daily.ResetsInSeconds < 0 {
progress.Daily.ResetsInSeconds = 0
}
}
// 周进度
if group.HasWeeklyLimit() && sub.WeeklyWindowStart != nil {
limit := *group.WeeklyLimitUSD
resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour)
progress.Weekly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.WeeklyUsageUSD,
RemainingUSD: limit - sub.WeeklyUsageUSD,
Percentage: (sub.WeeklyUsageUSD / limit) * 100,
WindowStart: *sub.WeeklyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Weekly.RemainingUSD < 0 {
progress.Weekly.RemainingUSD = 0
}
if progress.Weekly.Percentage > 100 {
progress.Weekly.Percentage = 100
}
if progress.Weekly.ResetsInSeconds < 0 {
progress.Weekly.ResetsInSeconds = 0
}
}
// 月进度
if group.HasMonthlyLimit() && sub.MonthlyWindowStart != nil {
limit := *group.MonthlyLimitUSD
resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour)
progress.Monthly = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.MonthlyUsageUSD,
RemainingUSD: limit - sub.MonthlyUsageUSD,
Percentage: (sub.MonthlyUsageUSD / limit) * 100,
WindowStart: *sub.MonthlyWindowStart,
ResetsAt: resetsAt,
ResetsInSeconds: int64(time.Until(resetsAt).Seconds()),
}
if progress.Monthly.RemainingUSD < 0 {
progress.Monthly.RemainingUSD = 0
}
if progress.Monthly.Percentage > 100 {
progress.Monthly.Percentage = 100
}
if progress.Monthly.ResetsInSeconds < 0 {
progress.Monthly.ResetsInSeconds = 0
}
}
return progress, nil
}
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
progresses := make([]SubscriptionProgress, 0, len(subs))
for _, sub := range subs {
progress, err := s.GetSubscriptionProgress(ctx, sub.ID)
if err != nil {
continue
}
progresses = append(progresses, *progress)
}
return progresses, nil
}
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
if sub.Status == model.SubscriptionStatusExpired {
return ErrSubscriptionExpired
}
if sub.Status == model.SubscriptionStatusSuspended {
return ErrSubscriptionSuspended
}
if sub.IsExpired() {
// 更新状态
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil
}

View File

@@ -0,0 +1,111 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
)
var (
ErrTurnstileVerificationFailed = errors.New("turnstile verification failed")
ErrTurnstileNotConfigured = errors.New("turnstile not configured")
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
// TurnstileService Turnstile 验证服务
type TurnstileService struct {
settingService *SettingService
httpClient *http.Client
}
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
type TurnstileVerifyResponse struct {
Success bool `json:"success"`
ChallengeTS string `json:"challenge_ts"`
Hostname string `json:"hostname"`
ErrorCodes []string `json:"error-codes"`
Action string `json:"action"`
CData string `json:"cdata"`
}
// NewTurnstileService 创建 Turnstile 服务实例
func NewTurnstileService(settingService *SettingService) *TurnstileService {
return &TurnstileService{
settingService: settingService,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// VerifyToken 验证 Turnstile token
func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remoteIP string) error {
// 检查是否启用 Turnstile
if !s.settingService.IsTurnstileEnabled(ctx) {
log.Println("[Turnstile] Disabled, skipping verification")
return nil
}
// 获取 Secret Key
secretKey := s.settingService.GetTurnstileSecretKey(ctx)
if secretKey == "" {
log.Println("[Turnstile] Secret key not configured")
return ErrTurnstileNotConfigured
}
// 如果 token 为空,返回错误
if token == "" {
log.Println("[Turnstile] Token is empty")
return ErrTurnstileVerificationFailed
}
// 构建请求
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// 发送请求
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
resp, err := s.httpClient.Do(req)
if err != nil {
log.Printf("[Turnstile] Request failed: %v", err)
return fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
// 解析响应
var result TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
log.Printf("[Turnstile] Failed to decode response: %v", err)
return fmt.Errorf("decode response: %w", err)
}
if !result.Success {
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
return ErrTurnstileVerificationFailed
}
log.Println("[Turnstile] Verification successful")
return nil
}
// IsEnabled 检查 Turnstile 是否启用
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
return s.settingService.IsTurnstileEnabled(ctx)
}

View File

@@ -0,0 +1,621 @@
package service
import (
"archive/tar"
"bufio"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
const (
updateCacheKey = "update_check_cache"
updateCacheTTL = 1200 // 20 minutes
githubRepo = "Wei-Shaw/sub2api"
// Security: allowed download domains for updates
allowedDownloadHost = "github.com"
allowedAssetHost = "objects.githubusercontent.com"
// Security: max download size (500MB)
maxDownloadSize = 500 * 1024 * 1024
)
// UpdateService handles software updates
type UpdateService struct {
rdb *redis.Client
currentVersion string
buildType string // "source" for manual builds, "release" for CI builds
}
// NewUpdateService creates a new UpdateService
func NewUpdateService(rdb *redis.Client, version, buildType string) *UpdateService {
return &UpdateService{
rdb: rdb,
currentVersion: version,
buildType: buildType,
}
}
// UpdateInfo contains update information
type UpdateInfo struct {
CurrentVersion string `json:"current_version"`
LatestVersion string `json:"latest_version"`
HasUpdate bool `json:"has_update"`
ReleaseInfo *ReleaseInfo `json:"release_info,omitempty"`
Cached bool `json:"cached"`
Warning string `json:"warning,omitempty"`
BuildType string `json:"build_type"` // "source" or "release"
}
// ReleaseInfo contains GitHub release details
type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
// Asset represents a release asset
type Asset struct {
Name string `json:"name"`
DownloadURL string `json:"download_url"`
Size int64 `json:"size"`
}
// GitHubRelease represents GitHub API response
type GitHubRelease struct {
TagName string `json:"tag_name"`
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"`
Size int64 `json:"size"`
}
// CheckUpdate checks for available updates
func (s *UpdateService) CheckUpdate(ctx context.Context, force bool) (*UpdateInfo, error) {
// Try cache first
if !force {
if cached, err := s.getFromCache(ctx); err == nil && cached != nil {
return cached, nil
}
}
// Fetch from GitHub
info, err := s.fetchLatestRelease(ctx)
if err != nil {
// Return cached on error
if cached, cacheErr := s.getFromCache(ctx); cacheErr == nil && cached != nil {
cached.Warning = "Using cached data: " + err.Error()
return cached, nil
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: err.Error(),
BuildType: s.buildType,
}, nil
}
// Cache result
s.saveToCache(ctx, info)
return info, nil
}
// PerformUpdate downloads and applies the update
func (s *UpdateService) PerformUpdate(ctx context.Context) error {
info, err := s.CheckUpdate(ctx, true)
if err != nil {
return err
}
if !info.HasUpdate {
return fmt.Errorf("no update available")
}
// Find matching archive and checksum for current platform
archiveName := s.getArchiveName()
var downloadURL string
var checksumURL string
for _, asset := range info.ReleaseInfo.Assets {
if strings.Contains(asset.Name, archiveName) && !strings.HasSuffix(asset.Name, ".txt") {
downloadURL = asset.DownloadURL
}
if asset.Name == "checksums.txt" {
checksumURL = asset.DownloadURL
}
}
if downloadURL == "" {
return fmt.Errorf("no compatible release found for %s/%s", runtime.GOOS, runtime.GOARCH)
}
// SECURITY: Validate download URL is from trusted domain
if err := validateDownloadURL(downloadURL); err != nil {
return fmt.Errorf("invalid download URL: %w", err)
}
if checksumURL != "" {
if err := validateDownloadURL(checksumURL); err != nil {
return fmt.Errorf("invalid checksum URL: %w", err)
}
}
// Get current executable path
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
exePath, err = filepath.EvalSymlinks(exePath)
if err != nil {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
// Create temp directory for extraction
tempDir, err := os.MkdirTemp("", "sub2api-update-*")
if err != nil {
return fmt.Errorf("failed to create temp dir: %w", err)
}
defer os.RemoveAll(tempDir)
// Download archive
archivePath := filepath.Join(tempDir, filepath.Base(downloadURL))
if err := s.downloadFile(ctx, downloadURL, archivePath); err != nil {
return fmt.Errorf("download failed: %w", err)
}
// Verify checksum if available
if checksumURL != "" {
if err := s.verifyChecksum(ctx, archivePath, checksumURL); err != nil {
return fmt.Errorf("checksum verification failed: %w", err)
}
}
// Extract binary from archive
newBinaryPath := filepath.Join(tempDir, "sub2api")
if err := s.extractBinary(archivePath, newBinaryPath); err != nil {
return fmt.Errorf("extraction failed: %w", err)
}
// Backup current binary
backupFile := exePath + ".backup"
if err := os.Rename(exePath, backupFile); err != nil {
return fmt.Errorf("backup failed: %w", err)
}
// Replace with new binary
if err := copyFile(newBinaryPath, exePath); err != nil {
os.Rename(backupFile, exePath)
return fmt.Errorf("replace failed: %w", err)
}
// Make executable
if err := os.Chmod(exePath, 0755); err != nil {
return fmt.Errorf("chmod failed: %w", err)
}
return nil
}
// Rollback restores the previous version
func (s *UpdateService) Rollback() error {
exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
exePath, err = filepath.EvalSymlinks(exePath)
if err != nil {
return fmt.Errorf("failed to resolve symlinks: %w", err)
}
backupFile := exePath + ".backup"
if _, err := os.Stat(backupFile); os.IsNotExist(err) {
return fmt.Errorf("no backup found")
}
// Replace current with backup
if err := os.Rename(backupFile, exePath); err != nil {
return fmt.Errorf("rollback failed: %w", err)
}
return nil
}
// RestartService triggers a service restart via systemd
func (s *UpdateService) RestartService() error {
if runtime.GOOS != "linux" {
return fmt.Errorf("systemd restart only available on Linux")
}
// Try direct systemctl first (works if running as root or with proper permissions)
cmd := exec.Command("systemctl", "restart", "sub2api")
if err := cmd.Run(); err != nil {
// Try with sudo (requires NOPASSWD sudoers entry)
sudoCmd := exec.Command("sudo", "systemctl", "restart", "sub2api")
if sudoErr := sudoCmd.Run(); sudoErr != nil {
return fmt.Errorf("systemctl restart failed: %w (sudo also failed: %v)", err, sudoErr)
}
}
return nil
}
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: "No releases found",
BuildType: s.buildType,
}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
latestVersion := strings.TrimPrefix(release.TagName, "v")
assets := make([]Asset, len(release.Assets))
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadUrl,
Size: a.Size,
}
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: latestVersion,
HasUpdate: compareVersions(s.currentVersion, latestVersion) < 0,
ReleaseInfo: &ReleaseInfo{
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl,
Assets: assets,
},
Cached: false,
BuildType: s.buildType,
}, nil
}
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxDownloadSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxDownloadSize)
if written > maxDownloadSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
}
return nil
}
func (s *UpdateService) getArchiveName() string {
osName := runtime.GOOS
arch := runtime.GOARCH
return fmt.Sprintf("%s_%s", osName, arch)
}
// validateDownloadURL checks if the URL is from an allowed domain
// SECURITY: This prevents SSRF and ensures downloads only come from trusted GitHub domains
func validateDownloadURL(rawURL string) error {
parsedURL, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
// Must be HTTPS
if parsedURL.Scheme != "https" {
return fmt.Errorf("only HTTPS URLs are allowed")
}
// Check against allowed hosts
host := parsedURL.Host
// GitHub release URLs can be from github.com or objects.githubusercontent.com
if host != allowedDownloadHost &&
!strings.HasSuffix(host, "."+allowedDownloadHost) &&
host != allowedAssetHost &&
!strings.HasSuffix(host, "."+allowedAssetHost) {
return fmt.Errorf("download from untrusted host: %s", host)
}
return nil
}
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
// Download checksums file
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
}
// Calculate file hash
f, err := os.Open(filePath)
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actualHash := hex.EncodeToString(h.Sum(nil))
// Find expected hash in checksums file
fileName := filepath.Base(filePath)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
parts := strings.Fields(line)
if len(parts) == 2 && parts[1] == fileName {
if parts[0] == actualHash {
return nil
}
return fmt.Errorf("checksum mismatch: expected %s, got %s", parts[0], actualHash)
}
}
return fmt.Errorf("checksum not found for %s", fileName)
}
func (s *UpdateService) extractBinary(archivePath, destPath string) error {
f, err := os.Open(archivePath)
if err != nil {
return err
}
defer f.Close()
var reader io.Reader = f
// Handle gzip compression
if strings.HasSuffix(archivePath, ".gz") || strings.HasSuffix(archivePath, ".tar.gz") || strings.HasSuffix(archivePath, ".tgz") {
gzr, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gzr.Close()
reader = gzr
}
// Handle tar archive
if strings.Contains(archivePath, ".tar") {
tr := tar.NewReader(reader)
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return err
}
// SECURITY: Prevent Zip Slip / Path Traversal attack
// Only allow files with safe base names, no directory traversal
baseName := filepath.Base(hdr.Name)
// Check for path traversal attempts
if strings.Contains(hdr.Name, "..") {
return fmt.Errorf("path traversal attempt detected: %s", hdr.Name)
}
// Validate the entry is a regular file
if hdr.Typeflag != tar.TypeReg {
continue // Skip directories and special files
}
// Only extract the specific binary we need
if baseName == "sub2api" || baseName == "sub2api.exe" {
// Additional security: limit file size (max 500MB)
const maxBinarySize = 500 * 1024 * 1024
if hdr.Size > maxBinarySize {
return fmt.Errorf("binary too large: %d bytes (max %d)", hdr.Size, maxBinarySize)
}
out, err := os.Create(destPath)
if err != nil {
return err
}
// Use LimitReader to prevent decompression bombs
limited := io.LimitReader(tr, maxBinarySize)
if _, err := io.Copy(out, limited); err != nil {
out.Close()
return err
}
out.Close()
return nil
}
}
return fmt.Errorf("binary not found in archive")
}
// Direct copy for non-tar files (with size limit)
const maxBinarySize = 500 * 1024 * 1024
out, err := os.Create(destPath)
if err != nil {
return err
}
defer out.Close()
limited := io.LimitReader(reader, maxBinarySize)
_, err = io.Copy(out, limited)
return err
}
func copyFile(src, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, in)
return err
}
func (s *UpdateService) getFromCache(ctx context.Context) (*UpdateInfo, error) {
data, err := s.rdb.Get(ctx, updateCacheKey).Result()
if err != nil {
return nil, err
}
var cached struct {
Latest string `json:"latest"`
ReleaseInfo *ReleaseInfo `json:"release_info"`
Timestamp int64 `json:"timestamp"`
}
if err := json.Unmarshal([]byte(data), &cached); err != nil {
return nil, err
}
if time.Now().Unix()-cached.Timestamp > updateCacheTTL {
return nil, fmt.Errorf("cache expired")
}
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: cached.Latest,
HasUpdate: compareVersions(s.currentVersion, cached.Latest) < 0,
ReleaseInfo: cached.ReleaseInfo,
Cached: true,
BuildType: s.buildType,
}, nil
}
func (s *UpdateService) saveToCache(ctx context.Context, info *UpdateInfo) {
cacheData := struct {
Latest string `json:"latest"`
ReleaseInfo *ReleaseInfo `json:"release_info"`
Timestamp int64 `json:"timestamp"`
}{
Latest: info.LatestVersion,
ReleaseInfo: info.ReleaseInfo,
Timestamp: time.Now().Unix(),
}
data, _ := json.Marshal(cacheData)
s.rdb.Set(ctx, updateCacheKey, data, time.Duration(updateCacheTTL)*time.Second)
}
// compareVersions compares two semantic versions
func compareVersions(current, latest string) int {
currentParts := parseVersion(current)
latestParts := parseVersion(latest)
for i := 0; i < 3; i++ {
if currentParts[i] < latestParts[i] {
return -1
}
if currentParts[i] > latestParts[i] {
return 1
}
}
return 0
}
func parseVersion(v string) [3]int {
v = strings.TrimPrefix(v, "v")
parts := strings.Split(v, ".")
result := [3]int{0, 0, 0}
for i := 0; i < len(parts) && i < 3; i++ {
fmt.Sscanf(parts[i], "%d", &result[i])
}
return result
}

View File

@@ -0,0 +1,283 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"time"
"gorm.io/gorm"
)
var (
ErrUsageLogNotFound = errors.New("usage log not found")
)
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationTokens int `json:"cache_creation_tokens"`
CacheReadTokens int `json:"cache_read_tokens"`
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
InputCost float64 `json:"input_cost"`
OutputCost float64 `json:"output_cost"`
CacheCreationCost float64 `json:"cache_creation_cost"`
CacheReadCost float64 `json:"cache_read_cost"`
TotalCost float64 `json:"total_cost"`
ActualCost float64 `json:"actual_cost"`
RateMultiplier float64 `json:"rate_multiplier"`
Stream bool `json:"stream"`
DurationMs *int `json:"duration_ms"`
}
// UsageStats 使用统计
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// UsageService 使用统计服务
type UsageService struct {
usageRepo *repository.UsageLogRepository
userRepo *repository.UserRepository
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
}
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 创建使用日志
usageLog := &model.UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
InputTokens: req.InputTokens,
OutputTokens: req.OutputTokens,
CacheCreationTokens: req.CacheCreationTokens,
CacheReadTokens: req.CacheReadTokens,
CacheCreation5mTokens: req.CacheCreation5mTokens,
CacheCreation1hTokens: req.CacheCreation1hTokens,
InputCost: req.InputCost,
OutputCost: req.OutputCost,
CacheCreationCost: req.CacheCreationCost,
CacheReadCost: req.CacheReadCost,
TotalCost: req.TotalCost,
ActualCost: req.ActualCost,
RateMultiplier: req.RateMultiplier,
Stream: req.Stream,
DurationMs: req.DurationMs,
}
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
return nil, fmt.Errorf("create usage log: %w", err)
}
// 扣除用户余额
if req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
}
return usageLog, nil
}
// GetByID 根据ID获取使用日志
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
log, err := s.usageRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUsageLogNotFound
}
return nil, fmt.Errorf("get usage log: %w", err)
}
return log, nil
}
// ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
return logs, pagination, nil
}
// GetStatsByUser 获取用户的使用统计
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByApiKeyAndTimeRange(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByAccount 获取账号的使用统计
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetStatsByModel 获取模型的使用统计
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
return s.calculateStats(logs), nil
}
// GetDailyStats 获取每日使用统计最近N天
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]interface{}, error) {
endTime := time.Now()
startTime := endTime.AddDate(0, 0, -days)
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("list usage logs: %w", err)
}
// 按日期分组统计
dailyStats := make(map[string]*UsageStats)
for _, log := range logs {
dateKey := log.CreatedAt.Format("2006-01-02")
if _, exists := dailyStats[dateKey]; !exists {
dailyStats[dateKey] = &UsageStats{}
}
stats := dailyStats[dateKey]
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均值并转换为数组
result := make([]map[string]interface{}, 0, len(dailyStats))
for date, stats := range dailyStats {
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
result = append(result, map[string]interface{}{
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
})
}
return result, nil
}
// calculateStats 计算统计数据
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
stats := &UsageStats{}
for _, log := range logs {
stats.TotalRequests++
stats.TotalInputTokens += int64(log.InputTokens)
stats.TotalOutputTokens += int64(log.OutputTokens)
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
stats.TotalTokens += int64(log.TotalTokens())
stats.TotalCost += log.TotalCost
stats.TotalActualCost += log.ActualCost
if log.DurationMs != nil {
stats.AverageDurationMs += float64(*log.DurationMs)
}
}
// 计算平均持续时间
if stats.TotalRequests > 0 {
stats.AverageDurationMs /= float64(stats.TotalRequests)
}
return stats
}
// Delete 删除使用日志(管理员功能,谨慎使用)
func (s *UsageService) Delete(ctx context.Context, id int64) error {
if err := s.usageRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete usage log: %w", err)
}
return nil
}

View File

@@ -0,0 +1,177 @@
package service
import (
"context"
"errors"
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
)
// UpdateProfileRequest 更新用户资料请求
type UpdateProfileRequest struct {
Email *string `json:"email"`
Concurrency *int `json:"concurrency"`
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
cfg *config.Config
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
return &UserService{
userRepo: userRepo,
cfg: cfg,
}
}
// GetProfile 获取用户资料
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
}
// UpdateProfile 更新用户资料
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
// 更新字段
if req.Email != nil {
// 检查新邮箱是否已被使用
exists, err := s.userRepo.ExistsByEmail(ctx, *req.Email)
if err != nil {
return nil, fmt.Errorf("check email exists: %w", err)
}
if exists && *req.Email != user.Email {
return nil, ErrEmailExists
}
user.Email = *req.Email
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("update user: %w", err)
}
return user, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
// 验证当前密码
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
return ErrPasswordIncorrect
}
// 生成新密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
user.PasswordHash = string(hashedPassword)
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
return nil
}
// GetByID 根据ID获取用户管理员功能
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, fmt.Errorf("get user: %w", err)
}
return user, nil
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
return users, pagination, nil
}
// UpdateBalance 更新用户余额(管理员功能)
func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount float64) error {
if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil {
return fmt.Errorf("update balance: %w", err)
}
return nil
}
// UpdateStatus 更新用户状态(管理员功能)
func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrUserNotFound
}
return fmt.Errorf("get user: %w", err)
}
user.Status = status
if err := s.userRepo.Update(ctx, user); err != nil {
return fmt.Errorf("update user: %w", err)
}
return nil
}
// Delete 删除用户(管理员功能)
func (s *UserService) Delete(ctx context.Context, userID int64) error {
if err := s.userRepo.Delete(ctx, userID); err != nil {
return fmt.Errorf("delete user: %w", err)
}
return nil
}