merge: 合并 upstream/main 解决 PR #37 冲突
- 删除 backend/internal/model/account.go 符合重构方向 - 合并最新的项目结构重构 - 包含 SSE 格式解析修复 - 更新依赖和配置文件
This commit is contained in:
324
backend/internal/service/account.go
Normal file
324
backend/internal/service/account.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type Account struct {
|
||||
ID int64
|
||||
Name string
|
||||
Platform string
|
||||
Type string
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency int
|
||||
Priority int
|
||||
Status string
|
||||
ErrorMessage string
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
Schedulable bool
|
||||
|
||||
RateLimitedAt *time.Time
|
||||
RateLimitResetAt *time.Time
|
||||
OverloadUntil *time.Time
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
|
||||
Proxy *Proxy
|
||||
AccountGroups []AccountGroup
|
||||
GroupIDs []int64
|
||||
Groups []*Group
|
||||
}
|
||||
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulable() bool {
|
||||
if !a.IsActive() || !a.Schedulable {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if a.OverloadUntil != nil && now.Before(*a.OverloadUntil) {
|
||||
return false
|
||||
}
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *Account) IsRateLimited() bool {
|
||||
if a.RateLimitResetAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.RateLimitResetAt)
|
||||
}
|
||||
|
||||
func (a *Account) IsOverloaded() bool {
|
||||
if a.OverloadUntil == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(*a.OverloadUntil)
|
||||
}
|
||||
|
||||
func (a *Account) IsOAuth() bool {
|
||||
return a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken
|
||||
}
|
||||
|
||||
func (a *Account) CanGetUsage() bool {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (a *Account) GetCredential(key string) string {
|
||||
if a.Credentials == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Credentials[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["model_mapping"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
if m, ok := raw.(map[string]any); ok {
|
||||
result := make(map[string]string)
|
||||
for k, v := range m {
|
||||
if s, ok := v.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) IsModelSupported(requestedModel string) bool {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return true
|
||||
}
|
||||
_, exists := mapping[requestedModel]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
mapping := a.GetModelMapping()
|
||||
if len(mapping) == 0 {
|
||||
return requestedModel
|
||||
}
|
||||
if mappedModel, exists := mapping[requestedModel]; exists {
|
||||
return mappedModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL == "" {
|
||||
return "https://api.anthropic.com"
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
func (a *Account) GetExtraString(key string) string {
|
||||
if a.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := a.Extra[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) GetCustomErrorCodes() []int {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["custom_error_codes"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
if arr, ok := raw.([]any); ok {
|
||||
result := make([]int, 0, len(arr))
|
||||
for _, v := range arr {
|
||||
if f, ok := v.(float64); ok {
|
||||
result = append(result, int(f))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) ShouldHandleErrorCode(statusCode int) bool {
|
||||
if !a.IsCustomErrorCodesEnabled() {
|
||||
return true
|
||||
}
|
||||
codes := a.GetCustomErrorCodes()
|
||||
if len(codes) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, code := range codes {
|
||||
if code == statusCode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["intercept_warmup_requests"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAI() bool {
|
||||
return a.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
func (a *Account) IsAnthropic() bool {
|
||||
return a.Platform == PlatformAnthropic
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIOAuth() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
}
|
||||
}
|
||||
return "https://api.openai.com"
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIAccessToken() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("access_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIRefreshToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("refresh_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIIDToken() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("id_token")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIApiKey() string {
|
||||
if !a.IsOpenAIApiKey() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("api_key")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIUserAgent() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("user_agent")
|
||||
}
|
||||
|
||||
func (a *Account) GetChatGPTAccountID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_account_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetChatGPTUserID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("chatgpt_user_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIOrganizationID() string {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return ""
|
||||
}
|
||||
return a.GetCredential("organization_id")
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
expiresAtStr := a.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return nil
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||
tt := time.Unix(int64(v), 0)
|
||||
return &tt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
expiresAt := a.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().Add(60 * time.Second).After(*expiresAt)
|
||||
}
|
||||
13
backend/internal/service/account_group.go
Normal file
13
backend/internal/service/account_group.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Priority int
|
||||
CreatedAt time.Time
|
||||
|
||||
Account *Account
|
||||
Group *Group
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -15,29 +14,29 @@ var (
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
Create(ctx context.Context, account *Account) error
|
||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
Update(ctx context.Context, account *Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListActive(ctx context.Context) ([]model.Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]model.Account, error)
|
||||
ListSchedulable(ctx context.Context) ([]Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
@@ -99,7 +98,7 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository)
|
||||
}
|
||||
|
||||
// Create 创建账号
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*model.Account, error) {
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
|
||||
// 验证分组是否存在(如果指定了分组)
|
||||
if len(req.GroupIDs) > 0 {
|
||||
for _, groupID := range req.GroupIDs {
|
||||
@@ -111,7 +110,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
}
|
||||
|
||||
// 创建账号
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: req.Name,
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
@@ -120,7 +119,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
@@ -138,7 +137,7 @@ func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取账号
|
||||
func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, error) {
|
||||
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
@@ -147,7 +146,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
@@ -156,7 +155,7 @@ func (s *AccountService) List(ctx context.Context, params pagination.PaginationP
|
||||
}
|
||||
|
||||
// ListByPlatform 根据平台获取账号列表
|
||||
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) {
|
||||
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by platform: %w", err)
|
||||
@@ -165,7 +164,7 @@ func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([
|
||||
}
|
||||
|
||||
// ListByGroup 根据分组获取账号列表
|
||||
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) {
|
||||
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by group: %w", err)
|
||||
@@ -174,7 +173,7 @@ func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]mode
|
||||
}
|
||||
|
||||
// Update 更新账号
|
||||
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*model.Account, error) {
|
||||
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
@@ -290,13 +289,13 @@ func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
|
||||
|
||||
// 根据平台执行不同的测试逻辑
|
||||
switch account.Platform {
|
||||
case model.PlatformAnthropic:
|
||||
case PlatformAnthropic:
|
||||
// TODO: 测试Anthropic API凭证
|
||||
return nil
|
||||
case model.PlatformOpenAI:
|
||||
case PlatformOpenAI:
|
||||
// TODO: 测试OpenAI API凭证
|
||||
return nil
|
||||
case model.PlatformGemini:
|
||||
case PlatformGemini:
|
||||
// TODO: 测试Gemini API凭证
|
||||
return nil
|
||||
default:
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
@@ -23,6 +23,10 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// sseDataPrefix matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
||||
@@ -141,7 +145,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
}
|
||||
|
||||
// testClaudeAccountConnection tests an Anthropic Claude account's connection
|
||||
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Determine the model to use
|
||||
@@ -268,7 +272,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *model.Account, modelID string) error {
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string) error {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
@@ -667,11 +671,11 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
if line == "" || !sseDataPrefix.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
@@ -721,11 +725,11 @@ func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
if line == "" || !sseDataPrefix.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := strings.TrimPrefix(line, "data: ")
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
|
||||
@@ -7,24 +7,23 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *model.UsageLog) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
|
||||
Create(ctx context.Context, log *UsageLog) error
|
||||
GetByID(ctx context.Context, id int64) (*UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
|
||||
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
|
||||
@@ -44,7 +43,7 @@ type UsageLogRepository interface {
|
||||
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)
|
||||
|
||||
// Admin usage listing/stats
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
|
||||
// Account stats
|
||||
@@ -163,7 +162,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
}
|
||||
|
||||
// Setup Token账号:根据session_window推算(没有profile scope,无法调用usage API)
|
||||
if account.Type == model.AccountTypeSetupToken {
|
||||
if account.Type == AccountTypeSetupToken {
|
||||
usage := s.estimateSetupTokenUsage(account)
|
||||
// 添加窗口统计
|
||||
s.addWindowStats(ctx, account, usage)
|
||||
@@ -175,7 +174,7 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
}
|
||||
|
||||
// addWindowStats 为usage数据添加窗口期统计
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model.Account, usage *UsageInfo) {
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
if usage.FiveHour == nil {
|
||||
return
|
||||
}
|
||||
@@ -225,7 +224,7 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token available")
|
||||
@@ -320,7 +319,7 @@ func (s *AccountUsageService) buildUsageInfo(resp *ClaudeUsageResponse, updatedA
|
||||
}
|
||||
|
||||
// estimateSetupTokenUsage 根据session_window推算Setup Token账号的使用量
|
||||
func (s *AccountUsageService) estimateSetupTokenUsage(account *model.Account) *UsageInfo {
|
||||
func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageInfo {
|
||||
info := &UsageInfo{}
|
||||
|
||||
// 如果有session_window信息
|
||||
|
||||
@@ -7,62 +7,61 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// AdminService interface defines admin management operations
|
||||
type AdminService interface {
|
||||
// User management
|
||||
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error)
|
||||
GetUser(ctx context.Context, id int64) (*model.User, error)
|
||||
CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error)
|
||||
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
|
||||
GetUser(ctx context.Context, id int64) (*User, error)
|
||||
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error)
|
||||
GetAllGroups(ctx context.Context) ([]model.Group, error)
|
||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
GetGroup(ctx context.Context, id int64) (*model.Group, error)
|
||||
CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error)
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error)
|
||||
ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error)
|
||||
GetAllGroups(ctx context.Context) ([]Group, error)
|
||||
GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
GetGroup(ctx context.Context, id int64) (*Group, error)
|
||||
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error)
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*model.Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error)
|
||||
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error)
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
GetAccount(ctx context.Context, id int64) (*Account, error)
|
||||
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
|
||||
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
|
||||
DeleteAccount(ctx context.Context, id int64) error
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*model.Account, error)
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error)
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
|
||||
// Proxy management
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error)
|
||||
GetAllProxies(ctx context.Context) ([]model.Proxy, error)
|
||||
GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
|
||||
GetProxy(ctx context.Context, id int64) (*model.Proxy, error)
|
||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error)
|
||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error)
|
||||
ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error)
|
||||
GetAllProxies(ctx context.Context) ([]Proxy, error)
|
||||
GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
|
||||
GetProxy(ctx context.Context, id int64) (*Proxy, error)
|
||||
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
|
||||
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
|
||||
DeleteProxy(ctx context.Context, id int64) error
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error)
|
||||
GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
|
||||
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
|
||||
|
||||
// Redeem code management
|
||||
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error)
|
||||
GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error)
|
||||
ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error)
|
||||
GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error)
|
||||
DeleteRedeemCode(ctx context.Context, id int64) error
|
||||
BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error)
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
}
|
||||
|
||||
// Input types for admin operations
|
||||
@@ -252,7 +251,7 @@ func NewAdminService(
|
||||
}
|
||||
|
||||
// User management implementations
|
||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
||||
if err != nil {
|
||||
@@ -261,20 +260,21 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, st
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*model.User, error) {
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*model.User, error) {
|
||||
user := &model.User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Wechat: input.Wechat,
|
||||
Notes: input.Notes,
|
||||
Role: "user", // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: model.StatusActive,
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
user := &User{
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
Wechat: input.Wechat,
|
||||
Notes: input.Notes,
|
||||
Role: RoleUser, // Always create as regular user, never admin
|
||||
Balance: input.Balance,
|
||||
Concurrency: input.Concurrency,
|
||||
Status: StatusActive,
|
||||
AllowedGroups: input.AllowedGroups,
|
||||
}
|
||||
if err := user.SetPassword(input.Password); err != nil {
|
||||
return nil, err
|
||||
@@ -285,7 +285,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*model.User, error) {
|
||||
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -335,16 +335,16 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
|
||||
concurrencyDiff := user.Concurrency - oldConcurrency
|
||||
if concurrencyDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
code, err := GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
adjustmentRecord := &RedeemCode{
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminConcurrency,
|
||||
Type: AdjustmentTypeAdminConcurrency,
|
||||
Value: float64(concurrencyDiff),
|
||||
Status: model.StatusUsed,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
}
|
||||
now := time.Now()
|
||||
@@ -369,7 +369,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error {
|
||||
return s.userRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*model.User, error) {
|
||||
func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -406,17 +406,17 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
|
||||
balanceDiff := user.Balance - oldBalance
|
||||
if balanceDiff != 0 {
|
||||
code, err := model.GenerateRedeemCode()
|
||||
code, err := GenerateRedeemCode()
|
||||
if err != nil {
|
||||
log.Printf("failed to generate adjustment redeem code: %v", err)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
adjustmentRecord := &model.RedeemCode{
|
||||
adjustmentRecord := &RedeemCode{
|
||||
Code: code,
|
||||
Type: model.AdjustmentTypeAdminBalance,
|
||||
Type: AdjustmentTypeAdminBalance,
|
||||
Value: balanceDiff,
|
||||
Status: model.StatusUsed,
|
||||
Status: StatusUsed,
|
||||
UsedBy: &user.ID,
|
||||
Notes: notes,
|
||||
}
|
||||
@@ -431,7 +431,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
@@ -452,7 +452,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
@@ -461,36 +461,36 @@ func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, p
|
||||
return groups, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]model.Group, error) {
|
||||
func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) {
|
||||
return s.groupRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]model.Group, error) {
|
||||
func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return s.groupRepo.ListActiveByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*model.Group, error) {
|
||||
func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.groupRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*model.Group, error) {
|
||||
func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) {
|
||||
platform := input.Platform
|
||||
if platform == "" {
|
||||
platform = model.PlatformAnthropic
|
||||
platform = PlatformAnthropic
|
||||
}
|
||||
|
||||
subscriptionType := input.SubscriptionType
|
||||
if subscriptionType == "" {
|
||||
subscriptionType = model.SubscriptionTypeStandard
|
||||
subscriptionType = SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
group := &model.Group{
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
Platform: platform,
|
||||
RateMultiplier: input.RateMultiplier,
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: input.DailyLimitUSD,
|
||||
WeeklyLimitUSD: input.WeeklyLimitUSD,
|
||||
@@ -502,7 +502,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*model.Group, error) {
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -571,7 +571,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
@@ -581,7 +581,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
}
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
@@ -590,21 +590,21 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
||||
return accounts, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*model.Account, error) {
|
||||
func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account, error) {
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*model.Account, error) {
|
||||
account := &model.Account{
|
||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Platform: input.Platform,
|
||||
Type: input.Type,
|
||||
Credentials: model.JSONB(input.Credentials),
|
||||
Extra: model.JSONB(input.Extra),
|
||||
Credentials: input.Credentials,
|
||||
Extra: input.Extra,
|
||||
ProxyID: input.ProxyID,
|
||||
Concurrency: input.Concurrency,
|
||||
Priority: input.Priority,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
@@ -618,7 +618,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*model.Account, error) {
|
||||
func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -631,10 +631,10 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Type = input.Type
|
||||
}
|
||||
if len(input.Credentials) > 0 {
|
||||
account.Credentials = model.JSONB(input.Credentials)
|
||||
account.Credentials = input.Credentials
|
||||
}
|
||||
if len(input.Extra) > 0 {
|
||||
account.Extra = model.JSONB(input.Extra)
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
account.ProxyID = input.ProxyID
|
||||
@@ -730,7 +730,7 @@ func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return s.accountRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*model.Account, error) {
|
||||
func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -739,12 +739,12 @@ func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*model.Account, error) {
|
||||
func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
account.Status = model.StatusActive
|
||||
account.Status = StatusActive
|
||||
account.ErrorMessage = ""
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, err
|
||||
@@ -752,7 +752,7 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*mo
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*model.Account, error) {
|
||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
|
||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -760,7 +760,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
||||
}
|
||||
|
||||
// Proxy management implementations
|
||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
||||
if err != nil {
|
||||
@@ -769,27 +769,27 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int,
|
||||
return proxies, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]model.Proxy, error) {
|
||||
func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
|
||||
return s.proxyRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) {
|
||||
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||
return s.proxyRepo.ListActiveWithAccountCount(ctx)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
|
||||
return s.proxyRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*model.Proxy, error) {
|
||||
proxy := &model.Proxy{
|
||||
func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error) {
|
||||
proxy := &Proxy{
|
||||
Name: input.Name,
|
||||
Protocol: input.Protocol,
|
||||
Host: input.Host,
|
||||
Port: input.Port,
|
||||
Username: input.Username,
|
||||
Password: input.Password,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
@@ -797,7 +797,7 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*model.Proxy, error) {
|
||||
func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -835,9 +835,9 @@ func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
|
||||
return s.proxyRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]model.Account, int64, error) {
|
||||
func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
|
||||
// Return mock data for now - would need a dedicated repository method
|
||||
return []model.Account{}, 0, nil
|
||||
return []Account{}, 0, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
@@ -845,7 +845,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
|
||||
}
|
||||
|
||||
// Redeem code management implementations
|
||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
||||
if err != nil {
|
||||
@@ -854,13 +854,13 @@ func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize i
|
||||
return codes, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
return s.redeemCodeRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]model.RedeemCode, error) {
|
||||
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
|
||||
// 如果是订阅类型,验证必须有 GroupID
|
||||
if input.Type == model.RedeemTypeSubscription {
|
||||
if input.Type == RedeemTypeSubscription {
|
||||
if input.GroupID == nil {
|
||||
return nil, errors.New("group_id is required for subscription type")
|
||||
}
|
||||
@@ -874,20 +874,20 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
|
||||
}
|
||||
}
|
||||
|
||||
codes := make([]model.RedeemCode, 0, input.Count)
|
||||
codes := make([]RedeemCode, 0, input.Count)
|
||||
for i := 0; i < input.Count; i++ {
|
||||
codeValue, err := model.GenerateRedeemCode()
|
||||
codeValue, err := GenerateRedeemCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code := model.RedeemCode{
|
||||
code := RedeemCode{
|
||||
Code: codeValue,
|
||||
Type: input.Type,
|
||||
Value: input.Value,
|
||||
Status: model.StatusUnused,
|
||||
Status: StatusUnused,
|
||||
}
|
||||
// 订阅类型专用字段
|
||||
if input.Type == model.RedeemTypeSubscription {
|
||||
if input.Type == RedeemTypeSubscription {
|
||||
code.GroupID = input.GroupID
|
||||
code.ValidityDays = input.ValidityDays
|
||||
if code.ValidityDays <= 0 {
|
||||
@@ -916,12 +916,12 @@ func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
code, err := s.redeemCodeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
code.Status = model.StatusExpired
|
||||
code.Status = StatusExpired
|
||||
if err := s.redeemCodeRepo.Update(ctx, code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
20
backend/internal/service/api_key.go
Normal file
20
backend/internal/service/api_key.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
@@ -4,16 +4,13 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -30,17 +27,17 @@ const (
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *model.ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
Update(ctx context.Context, key *model.ApiKey) error
|
||||
Create(ctx context.Context, key *ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
GetByKey(ctx context.Context, key string) (*ApiKey, error)
|
||||
Update(ctx context.Context, key *ApiKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
@@ -144,7 +141,7 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
@@ -168,7 +165,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User, group *model.Group) bool {
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
@@ -179,7 +176,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *model.User,
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*model.ApiKey, error) {
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -235,12 +232,12 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &model.ApiKey{
|
||||
apiKey := &ApiKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
@@ -251,7 +248,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -260,7 +257,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -269,7 +266,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*model.ApiKey, e
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey, error) {
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
@@ -289,7 +286,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*model.ApiKey
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*model.ApiKey, error) {
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -364,7 +361,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*model.ApiKey, *model.User, error) {
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
@@ -408,7 +405,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]model.Group, error) {
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -434,7 +431,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
}
|
||||
|
||||
// 过滤出用户有权限的分组
|
||||
availableGroups := make([]model.Group, 0)
|
||||
availableGroups := make([]Group, 0)
|
||||
for _, group := range allGroups {
|
||||
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
|
||||
availableGroups = append(availableGroups, group)
|
||||
@@ -445,7 +442,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
@@ -454,7 +451,7 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *model.User, group *model.
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) {
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -64,12 +63,12 @@ func NewAuthService(
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) {
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *model.User, error) {
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// 检查是否开放注册
|
||||
if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -113,13 +112,13 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
user := &User{
|
||||
Email: email,
|
||||
PasswordHash: hashedPassword,
|
||||
Role: model.RoleUser,
|
||||
Role: RoleUser,
|
||||
Balance: defaultBalance,
|
||||
Concurrency: defaultConcurrency,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
@@ -251,7 +250,7 @@ func (s *AuthService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
}
|
||||
|
||||
// Login 用户登录,返回JWT token
|
||||
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *model.User, error) {
|
||||
func (s *AuthService) Login(ctx context.Context, email, password string) (string, *User, error) {
|
||||
// 查找用户
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
@@ -307,7 +306,7 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (s *AuthService) GenerateToken(user *model.User) (string, error) {
|
||||
func (s *AuthService) GenerateToken(user *User) (string, error) {
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -224,7 +223,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *model.User, apiKey *model.ApiKey, group *model.Group, subscription *model.UserSubscription) error {
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
@@ -252,7 +251,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI
|
||||
}
|
||||
|
||||
// checkSubscriptionEligibility 检查订阅模式资格
|
||||
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *model.Group, subscription *model.UserSubscription) error {
|
||||
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
@@ -262,7 +261,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
if subData.Status != model.SubscriptionStatusActive {
|
||||
if subData.Status != SubscriptionStatusActive {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
@@ -288,7 +287,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context,
|
||||
}
|
||||
|
||||
// checkSubscriptionLimitsFallback 降级检查订阅限额
|
||||
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *model.UserSubscription, group *model.Group) error {
|
||||
func (s *BillingCacheService) checkSubscriptionLimitsFallback(subscription *UserSubscription, group *Group) error {
|
||||
if subscription == nil {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -217,7 +215,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
),
|
||||
}
|
||||
|
||||
var proxies []model.Proxy
|
||||
var proxies []Proxy
|
||||
if input.SyncProxies {
|
||||
proxies, _ = s.proxyRepo.ListActive(ctx)
|
||||
}
|
||||
@@ -234,7 +232,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
if targetType == "" {
|
||||
targetType = "oauth"
|
||||
}
|
||||
if targetType != model.AccountTypeOAuth && targetType != model.AccountTypeSetupToken {
|
||||
if targetType != AccountTypeOAuth && targetType != AccountTypeSetupToken {
|
||||
item.Action = "skipped"
|
||||
item.Error = "unsupported authType: " + targetType
|
||||
result.Skipped++
|
||||
@@ -305,12 +303,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformAnthropic,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: targetType,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
@@ -325,7 +323,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
continue
|
||||
}
|
||||
// 🔄 Refresh OAuth token after creation
|
||||
if targetType == model.AccountTypeOAuth {
|
||||
if targetType == AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, account); refreshedCreds != nil {
|
||||
account.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, account)
|
||||
@@ -338,11 +336,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
// Update existing
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformAnthropic
|
||||
existing.Platform = PlatformAnthropic
|
||||
existing.Type = targetType
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
@@ -360,7 +358,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
// 🔄 Refresh OAuth token after update
|
||||
if targetType == model.AccountTypeOAuth {
|
||||
if targetType == AccountTypeOAuth {
|
||||
if refreshedCreds := s.refreshOAuthToken(ctx, existing); refreshedCreds != nil {
|
||||
existing.Credentials = refreshedCreds
|
||||
_ = s.accountRepo.Update(ctx, existing)
|
||||
@@ -422,12 +420,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformAnthropic,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeApiKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
@@ -447,11 +445,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformAnthropic
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
existing.Platform = PlatformAnthropic
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
@@ -545,12 +543,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformOpenAI,
|
||||
Type: model.AccountTypeOAuth,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
@@ -575,11 +573,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformOpenAI
|
||||
existing.Type = model.AccountTypeOAuth
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
existing.Platform = PlatformOpenAI
|
||||
existing.Type = AccountTypeOAuth
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
@@ -666,12 +664,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
}
|
||||
|
||||
if existing == nil {
|
||||
account := &model.Account{
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: model.PlatformOpenAI,
|
||||
Type: model.AccountTypeApiKey,
|
||||
Credentials: model.JSONB(credentials),
|
||||
Extra: model.JSONB(extra),
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeApiKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
Concurrency: concurrency,
|
||||
Priority: priority,
|
||||
@@ -691,11 +689,11 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
continue
|
||||
}
|
||||
|
||||
existing.Extra = mergeJSONB(existing.Extra, extra)
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = model.PlatformOpenAI
|
||||
existing.Type = model.AccountTypeApiKey
|
||||
existing.Credentials = mergeJSONB(existing.Credentials, credentials)
|
||||
existing.Platform = PlatformOpenAI
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
}
|
||||
@@ -939,9 +937,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// mergeJSONB merges two JSONB maps without removing keys that are absent in updates.
|
||||
func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
|
||||
out := make(model.JSONB)
|
||||
func mergeMap(existing map[string]any, updates map[string]any) map[string]any {
|
||||
out := make(map[string]any, len(existing)+len(updates))
|
||||
for k, v := range existing {
|
||||
out[k] = v
|
||||
}
|
||||
@@ -951,7 +948,7 @@ func mergeJSONB(existing model.JSONB, updates map[string]any) model.JSONB {
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]model.Proxy, src *crsProxy, defaultName string) (*int64, error) {
|
||||
func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cached *[]Proxy, src *crsProxy, defaultName string) (*int64, error) {
|
||||
if !enabled || src == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -987,14 +984,14 @@ func (s *CRSSyncService) mapOrCreateProxy(ctx context.Context, enabled bool, cac
|
||||
}
|
||||
|
||||
// Create new proxy
|
||||
proxy := &model.Proxy{
|
||||
proxy := &Proxy{
|
||||
Name: defaultProxyName(defaultName, protocol, host, port),
|
||||
Protocol: protocol,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
return nil, err
|
||||
@@ -1153,8 +1150,8 @@ func crsExportAccounts(ctx context.Context, client *http.Client, baseURL, adminT
|
||||
|
||||
// refreshOAuthToken attempts to refresh OAuth token for a synced account
|
||||
// Returns updated credentials or nil if refresh failed/not applicable
|
||||
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.Account) model.JSONB {
|
||||
if account.Type != model.AccountTypeOAuth {
|
||||
func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account) map[string]any {
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1162,7 +1159,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
|
||||
var err error
|
||||
|
||||
switch account.Platform {
|
||||
case model.PlatformAnthropic:
|
||||
case PlatformAnthropic:
|
||||
if s.oauthService == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1187,7 +1184,7 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
|
||||
newCredentials["scope"] = tokenInfo.Scope
|
||||
}
|
||||
}
|
||||
case model.PlatformOpenAI:
|
||||
case PlatformOpenAI:
|
||||
if s.openaiOAuthService == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1227,5 +1224,5 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *model.A
|
||||
return nil
|
||||
}
|
||||
|
||||
return model.JSONB(newCredentials)
|
||||
return newCredentials
|
||||
}
|
||||
|
||||
96
backend/internal/service/domain_constants.go
Normal file
96
backend/internal/service/domain_constants.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package service
|
||||
|
||||
// Status constants
|
||||
const (
|
||||
StatusActive = "active"
|
||||
StatusDisabled = "disabled"
|
||||
StatusError = "error"
|
||||
StatusUnused = "unused"
|
||||
StatusUsed = "used"
|
||||
StatusExpired = "expired"
|
||||
)
|
||||
|
||||
// Role constants
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// Platform constants
|
||||
const (
|
||||
PlatformAnthropic = "anthropic"
|
||||
PlatformOpenAI = "openai"
|
||||
PlatformGemini = "gemini"
|
||||
)
|
||||
|
||||
// Account type constants
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
const (
|
||||
RedeemTypeBalance = "balance"
|
||||
RedeemTypeConcurrency = "concurrency"
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
AdjustmentTypeAdminConcurrency = "admin_concurrency" // 管理员调整并发数
|
||||
)
|
||||
|
||||
// Group subscription type constants
|
||||
const (
|
||||
SubscriptionTypeStandard = "standard" // 标准计费模式(按余额扣费)
|
||||
SubscriptionTypeSubscription = "subscription" // 订阅模式(按限额控制)
|
||||
)
|
||||
|
||||
// Subscription status constants
|
||||
const (
|
||||
SubscriptionStatusActive = "active"
|
||||
SubscriptionStatusExpired = "expired"
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
)
|
||||
|
||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,13 +68,13 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
||||
// GetSmtpConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
keys := []string{
|
||||
model.SettingKeySmtpHost,
|
||||
model.SettingKeySmtpPort,
|
||||
model.SettingKeySmtpUsername,
|
||||
model.SettingKeySmtpPassword,
|
||||
model.SettingKeySmtpFrom,
|
||||
model.SettingKeySmtpFromName,
|
||||
model.SettingKeySmtpUseTLS,
|
||||
SettingKeySmtpHost,
|
||||
SettingKeySmtpPort,
|
||||
SettingKeySmtpUsername,
|
||||
SettingKeySmtpPassword,
|
||||
SettingKeySmtpFrom,
|
||||
SettingKeySmtpFromName,
|
||||
SettingKeySmtpUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -83,27 +82,27 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[model.SettingKeySmtpHost]
|
||||
host := settings[SettingKeySmtpHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[model.SettingKeySmtpPort]; portStr != "" {
|
||||
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[model.SettingKeySmtpUseTLS] == "true"
|
||||
useTLS := settings[SettingKeySmtpUseTLS] == "true"
|
||||
|
||||
return &SmtpConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[model.SettingKeySmtpUsername],
|
||||
Password: settings[model.SettingKeySmtpPassword],
|
||||
From: settings[model.SettingKeySmtpFrom],
|
||||
FromName: settings[model.SettingKeySmtpFromName],
|
||||
Username: settings[SettingKeySmtpUsername],
|
||||
Password: settings[SettingKeySmtpPassword],
|
||||
From: settings[SettingKeySmtpFrom],
|
||||
FromName: settings[SettingKeySmtpFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -31,6 +30,10 @@ const (
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
"accept": true,
|
||||
@@ -265,12 +268,12 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
|
||||
}
|
||||
|
||||
// SelectAccount 选择账号(粘性会话+优先级)
|
||||
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
|
||||
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
}
|
||||
|
||||
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
@@ -289,19 +292,19 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台)
|
||||
var accounts []model.Account
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformAnthropic)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformAnthropic)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
// 检查模型支持
|
||||
@@ -350,12 +353,12 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth, model.AccountTypeSetupToken:
|
||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||
// Both oauth and setup-token use OAuth token flow
|
||||
return s.getOAuthToken(ctx, account)
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -366,7 +369,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *model.Acco
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
@@ -381,10 +384,7 @@ const (
|
||||
retryDelay = 3 * time.Second // 重试等待时间
|
||||
)
|
||||
|
||||
// shouldRetryUpstreamError 判断是否应该重试上游错误
|
||||
// OAuth/Setup Token 账号:仅 403 重试
|
||||
// API Key 账号:未配置的错误码重试
|
||||
func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, statusCode int) bool {
|
||||
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
|
||||
// OAuth/Setup Token 账号:仅 403 重试
|
||||
if account.IsOAuth() {
|
||||
return statusCode == 403
|
||||
@@ -395,7 +395,7 @@ func (s *GatewayService) shouldRetryUpstreamError(account *model.Account, status
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求获取model和stream
|
||||
@@ -421,7 +421,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
// 替换请求体中的模型名
|
||||
@@ -513,10 +513,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
}
|
||||
@@ -640,7 +640,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 处理上游错误,标记账号状态
|
||||
@@ -695,7 +695,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
// handleRetryExhaustedError 处理重试耗尽后的错误
|
||||
// OAuth 403:标记账号异常
|
||||
// API Key 未配置错误码:仅返回错误,不标记账号
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*ForwardResult, error) {
|
||||
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
statusCode := resp.StatusCode
|
||||
|
||||
@@ -726,7 +726,7 @@ type streamingResult struct {
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -758,26 +758,33 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace && strings.HasPrefix(line, "data: ") {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if sseDataRe.MatchString(line) {
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// 转发行
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// 转发行
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// 解析usage数据
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := line[6:]
|
||||
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// 非 data 行直接转发
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -790,7 +797,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
|
||||
// replaceModelInSSELine 替换SSE数据行中的model字段
|
||||
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
data := line[6:] // 去掉 "data: " 前缀
|
||||
if !sseDataRe.MatchString(line) {
|
||||
return line
|
||||
}
|
||||
data := sseDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
@@ -865,7 +875,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
@@ -924,10 +934,10 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ApiKey *model.ApiKey
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
Subscription *model.UserSubscription // 可选:订阅信息
|
||||
ApiKey *ApiKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
}
|
||||
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
@@ -961,14 +971,14 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := model.BillingTypeBalance
|
||||
billingType := BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = model.BillingTypeSubscription
|
||||
billingType = BillingTypeSubscription
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &model.UsageLog{
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
@@ -1047,9 +1057,9 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *model.Account, body []byte) error {
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
@@ -1122,10 +1132,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == model.AccountTypeApiKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
|
||||
48
backend/internal/service/group.go
Normal file
48
backend/internal/service/group.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type Group struct {
|
||||
ID int64
|
||||
Name string
|
||||
Description string
|
||||
Platform string
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
Status string
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
WeeklyLimitUSD *float64
|
||||
MonthlyLimitUSD *float64
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
AccountGroups []AccountGroup
|
||||
AccountCount int64
|
||||
}
|
||||
|
||||
func (g *Group) IsActive() bool {
|
||||
return g.Status == StatusActive
|
||||
}
|
||||
|
||||
func (g *Group) IsSubscriptionType() bool {
|
||||
return g.SubscriptionType == SubscriptionTypeSubscription
|
||||
}
|
||||
|
||||
func (g *Group) IsFreeSubscription() bool {
|
||||
return g.IsSubscriptionType() && g.RateMultiplier == 0
|
||||
}
|
||||
|
||||
func (g *Group) HasDailyLimit() bool {
|
||||
return g.DailyLimitUSD != nil && *g.DailyLimitUSD > 0
|
||||
}
|
||||
|
||||
func (g *Group) HasWeeklyLimit() bool {
|
||||
return g.WeeklyLimitUSD != nil && *g.WeeklyLimitUSD > 0
|
||||
}
|
||||
|
||||
func (g *Group) HasMonthlyLimit() bool {
|
||||
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -15,16 +14,16 @@ var (
|
||||
)
|
||||
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *model.Group) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Group, error)
|
||||
Update(ctx context.Context, group *model.Group) error
|
||||
Create(ctx context.Context, group *Group) error
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
Update(ctx context.Context, group *Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error)
|
||||
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
@@ -61,7 +60,7 @@ func NewGroupService(groupRepo GroupRepository) *GroupService {
|
||||
}
|
||||
|
||||
// Create 创建分组
|
||||
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*model.Group, error) {
|
||||
func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*Group, error) {
|
||||
// 检查名称是否已存在
|
||||
exists, err := s.groupRepo.ExistsByName(ctx, req.Name)
|
||||
if err != nil {
|
||||
@@ -72,12 +71,14 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
|
||||
}
|
||||
|
||||
// 创建分组
|
||||
group := &model.Group{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: model.StatusActive,
|
||||
group := &Group{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: PlatformAnthropic,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
@@ -88,7 +89,7 @@ func (s *GroupService) Create(ctx context.Context, req CreateGroupRequest) (*mod
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取分组
|
||||
func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, error) {
|
||||
func (s *GroupService) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
@@ -97,7 +98,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
|
||||
}
|
||||
|
||||
// List 获取分组列表
|
||||
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
groups, pagination, err := s.groupRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list groups: %w", err)
|
||||
@@ -106,7 +107,7 @@ func (s *GroupService) List(ctx context.Context, params pagination.PaginationPar
|
||||
}
|
||||
|
||||
// ListActive 获取活跃分组列表
|
||||
func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
func (s *GroupService) ListActive(ctx context.Context) ([]Group, error) {
|
||||
groups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
@@ -115,7 +116,7 @@ func (s *GroupService) ListActive(ctx context.Context) ([]model.Group, error) {
|
||||
}
|
||||
|
||||
// Update 更新分组
|
||||
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*model.Group, error) {
|
||||
func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequest) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
@@ -274,7 +273,7 @@ func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, pr
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an account
|
||||
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*TokenInfo, error) {
|
||||
func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh token available")
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -28,6 +28,10 @@ const (
|
||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
"accept-language": true,
|
||||
@@ -119,12 +123,12 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
}
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*model.Account, error) {
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
}
|
||||
|
||||
// SelectAccountForModel selects an account supporting the requested model
|
||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*model.Account, error) {
|
||||
func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
|
||||
// 1. Check sticky session
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
@@ -139,19 +143,19 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// 2. Get schedulable OpenAI accounts
|
||||
var accounts []model.Account
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformOpenAI)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformOpenAI)
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. Select by priority + LRU
|
||||
var selected *model.Account
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
// Check model support
|
||||
@@ -198,15 +202,15 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// GetAccessToken gets the access token for an OpenAI account
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *model.Account) (string, string, error) {
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
accessToken := account.GetOpenAIAccessToken()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -218,7 +222,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *mode
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *model.Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
@@ -243,7 +247,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// For OAuth accounts using ChatGPT internal API, add store: false
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
if account.Type == AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
bodyModified = true
|
||||
}
|
||||
@@ -305,7 +309,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
@@ -321,14 +325,14 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) {
|
||||
// Determine target URL based on account type
|
||||
var targetURL string
|
||||
switch account.Type {
|
||||
case model.AccountTypeOAuth:
|
||||
case AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case model.AccountTypeApiKey:
|
||||
case AccountTypeApiKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
@@ -349,7 +353,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
req.Header.Set("authorization", "Bearer "+token)
|
||||
|
||||
// Set headers specific to OAuth accounts (ChatGPT internal API)
|
||||
if account.Type == model.AccountTypeOAuth {
|
||||
if account.Type == AccountTypeOAuth {
|
||||
// Required: set Host for ChatGPT API (must use req.Host, not Header.Set)
|
||||
req.Host = "chatgpt.com"
|
||||
// Required: set chatgpt-account-id header
|
||||
@@ -389,7 +393,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account) (*OpenAIForwardResult, error) {
|
||||
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// Check custom error codes
|
||||
@@ -445,7 +449,7 @@ type openaiStreamingResult struct {
|
||||
firstTokenMs *int
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
|
||||
// Set SSE response headers
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
@@ -473,26 +477,33 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace && strings.HasPrefix(line, "data: ") {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
|
||||
}
|
||||
|
||||
// Forward line
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Parse usage data
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := line[6:]
|
||||
// Record first token time
|
||||
if firstTokenMs == nil && data != "" && data != "[DONE]" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
s.parseSSEUsage(data, usage)
|
||||
} else {
|
||||
// Forward non-data lines as-is
|
||||
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
|
||||
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -504,7 +515,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
data := line[6:]
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
return line
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
@@ -561,7 +575,7 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *model.Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -627,10 +641,10 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *model.ApiKey
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
Subscription *model.UserSubscription
|
||||
ApiKey *ApiKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
}
|
||||
|
||||
// RecordUsage records usage and deducts balance
|
||||
@@ -669,14 +683,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
|
||||
// Determine billing type
|
||||
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
billingType := model.BillingTypeBalance
|
||||
billingType := BillingTypeBalance
|
||||
if isSubscriptionBilling {
|
||||
billingType = model.BillingTypeSubscription
|
||||
billingType = BillingTypeSubscription
|
||||
}
|
||||
|
||||
// Create usage log
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &model.UsageLog{
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
@@ -200,7 +199,7 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri
|
||||
}
|
||||
|
||||
// RefreshAccountToken refreshes token for an OpenAI account
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *model.Account) (*OpenAITokenInfo, error) {
|
||||
func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
|
||||
if !account.IsOpenAI() {
|
||||
return nil, fmt.Errorf("account is not an OpenAI account")
|
||||
}
|
||||
|
||||
35
backend/internal/service/proxy.go
Normal file
35
backend/internal/service/proxy.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
ID int64
|
||||
Name string
|
||||
Protocol string
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (p *Proxy) IsActive() bool {
|
||||
return p.Status == StatusActive
|
||||
}
|
||||
|
||||
func (p *Proxy) URL() string {
|
||||
if p.Username != "" && p.Password != "" {
|
||||
return fmt.Sprintf("%s://%s:%s@%s:%d", p.Protocol, p.Username, p.Password, p.Host, p.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.Host, p.Port)
|
||||
}
|
||||
|
||||
type ProxyWithAccountCount struct {
|
||||
Proxy
|
||||
AccountCount int64
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -14,15 +13,15 @@ var (
|
||||
)
|
||||
|
||||
type ProxyRepository interface {
|
||||
Create(ctx context.Context, proxy *model.Proxy) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
|
||||
Update(ctx context.Context, proxy *model.Proxy) error
|
||||
Create(ctx context.Context, proxy *Proxy) error
|
||||
GetByID(ctx context.Context, id int64) (*Proxy, error)
|
||||
Update(ctx context.Context, proxy *Proxy) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Proxy, error)
|
||||
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]Proxy, error)
|
||||
ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error)
|
||||
|
||||
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
|
||||
@@ -62,16 +61,16 @@ func NewProxyService(proxyRepo ProxyRepository) *ProxyService {
|
||||
}
|
||||
|
||||
// Create 创建代理
|
||||
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*model.Proxy, error) {
|
||||
func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*Proxy, error) {
|
||||
// 创建代理
|
||||
proxy := &model.Proxy{
|
||||
proxy := &Proxy{
|
||||
Name: req.Name,
|
||||
Protocol: req.Protocol,
|
||||
Host: req.Host,
|
||||
Port: req.Port,
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
Status: model.StatusActive,
|
||||
Status: StatusActive,
|
||||
}
|
||||
|
||||
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
|
||||
@@ -82,7 +81,7 @@ func (s *ProxyService) Create(ctx context.Context, req CreateProxyRequest) (*mod
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取代理
|
||||
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, error) {
|
||||
func (s *ProxyService) GetByID(ctx context.Context, id int64) (*Proxy, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get proxy: %w", err)
|
||||
@@ -91,7 +90,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
|
||||
}
|
||||
|
||||
// List 获取代理列表
|
||||
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
proxies, pagination, err := s.proxyRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list proxies: %w", err)
|
||||
@@ -100,7 +99,7 @@ func (s *ProxyService) List(ctx context.Context, params pagination.PaginationPar
|
||||
}
|
||||
|
||||
// ListActive 获取活跃代理列表
|
||||
func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
func (s *ProxyService) ListActive(ctx context.Context) ([]Proxy, error) {
|
||||
proxies, err := s.proxyRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active proxies: %w", err)
|
||||
@@ -109,7 +108,7 @@ func (s *ProxyService) ListActive(ctx context.Context) ([]model.Proxy, error) {
|
||||
}
|
||||
|
||||
// Update 更新代理
|
||||
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*model.Proxy, error) {
|
||||
func (s *ProxyService) Update(ctx context.Context, id int64, req UpdateProxyRequest) (*Proxy, error) {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get proxy: %w", err)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
@@ -27,7 +26,7 @@ func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *Rat
|
||||
|
||||
// HandleUpstreamError 处理上游错误响应,标记账号状态
|
||||
// 返回是否应该停止该账号的调度
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *model.Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
|
||||
// apikey 类型账号:检查自定义错误码配置
|
||||
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
@@ -60,7 +59,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
|
||||
}
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
@@ -70,7 +69,7 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.A
|
||||
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *model.Account, headers http.Header) {
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
|
||||
// 解析重置时间戳
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
if resetTimestamp == "" {
|
||||
@@ -113,7 +112,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
|
||||
// handle529 处理529过载错误
|
||||
// 根据配置设置过载冷却时间
|
||||
func (s *RateLimitService) handle529(ctx context.Context, account *model.Account) {
|
||||
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes
|
||||
if cooldownMinutes <= 0 {
|
||||
cooldownMinutes = 10 // 默认10分钟
|
||||
@@ -129,7 +128,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
|
||||
}
|
||||
|
||||
// UpdateSessionWindow 从成功响应更新5h窗口状态
|
||||
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *model.Account, headers http.Header) {
|
||||
func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Account, headers http.Header) {
|
||||
status := headers.Get("anthropic-ratelimit-unified-5h-status")
|
||||
if status == "" {
|
||||
return
|
||||
|
||||
41
backend/internal/service/redeem_code.go
Normal file
41
backend/internal/service/redeem_code.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RedeemCode struct {
|
||||
ID int64
|
||||
Code string
|
||||
Type string
|
||||
Value float64
|
||||
Status string
|
||||
UsedBy *int64
|
||||
UsedAt *time.Time
|
||||
Notes string
|
||||
CreatedAt time.Time
|
||||
|
||||
GroupID *int64
|
||||
ValidityDays int
|
||||
|
||||
User *User
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (r *RedeemCode) IsUsed() bool {
|
||||
return r.Status == StatusUsed
|
||||
}
|
||||
|
||||
func (r *RedeemCode) CanUse() bool {
|
||||
return r.Status == StatusUnused
|
||||
}
|
||||
|
||||
func GenerateRedeemCode() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
@@ -10,9 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,17 +37,17 @@ type RedeemCache interface {
|
||||
}
|
||||
|
||||
type RedeemCodeRepository interface {
|
||||
Create(ctx context.Context, code *model.RedeemCode) error
|
||||
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
|
||||
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
|
||||
Update(ctx context.Context, code *model.RedeemCode) error
|
||||
Create(ctx context.Context, code *RedeemCode) error
|
||||
CreateBatch(ctx context.Context, codes []RedeemCode) error
|
||||
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
|
||||
Update(ctx context.Context, code *RedeemCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Use(ctx context.Context, id, userID int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error)
|
||||
}
|
||||
|
||||
// GenerateCodesRequest 生成兑换码请求
|
||||
@@ -116,7 +114,7 @@ func (s *RedeemService) GenerateRandomCode() (string, error) {
|
||||
}
|
||||
|
||||
// GenerateCodes 批量生成兑换码
|
||||
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]model.RedeemCode, error) {
|
||||
func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequest) ([]RedeemCode, error) {
|
||||
if req.Count <= 0 {
|
||||
return nil, errors.New("count must be greater than 0")
|
||||
}
|
||||
@@ -131,21 +129,21 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
|
||||
|
||||
codeType := req.Type
|
||||
if codeType == "" {
|
||||
codeType = model.RedeemTypeBalance
|
||||
codeType = RedeemTypeBalance
|
||||
}
|
||||
|
||||
codes := make([]model.RedeemCode, 0, req.Count)
|
||||
codes := make([]RedeemCode, 0, req.Count)
|
||||
for i := 0; i < req.Count; i++ {
|
||||
code, err := s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate code: %w", err)
|
||||
}
|
||||
|
||||
codes = append(codes, model.RedeemCode{
|
||||
codes = append(codes, RedeemCode{
|
||||
Code: code,
|
||||
Type: codeType,
|
||||
Value: req.Value,
|
||||
Status: model.StatusUnused,
|
||||
Status: StatusUnused,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -164,7 +162,7 @@ func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
count, err := s.cache.GetRedeemAttemptCount(ctx, userID)
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
@@ -210,7 +208,7 @@ func (s *RedeemService) releaseRedeemLock(ctx context.Context, code string) {
|
||||
}
|
||||
|
||||
// Redeem 使用兑换码
|
||||
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*model.RedeemCode, error) {
|
||||
func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (*RedeemCode, error) {
|
||||
// 检查限流
|
||||
if err := s.checkRedeemRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
@@ -239,7 +237,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
}
|
||||
|
||||
// 验证兑换码类型的前置条件
|
||||
if redeemCode.Type == model.RedeemTypeSubscription && redeemCode.GroupID == nil {
|
||||
if redeemCode.Type == RedeemTypeSubscription && redeemCode.GroupID == nil {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_INVALID", "invalid subscription redeem code: missing group_id")
|
||||
}
|
||||
|
||||
@@ -261,7 +259,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
|
||||
// 执行兑换逻辑(兑换码已被锁定,此时可安全操作)
|
||||
switch redeemCode.Type {
|
||||
case model.RedeemTypeBalance:
|
||||
case RedeemTypeBalance:
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
|
||||
return nil, fmt.Errorf("update user balance: %w", err)
|
||||
@@ -275,13 +273,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
}()
|
||||
}
|
||||
|
||||
case model.RedeemTypeConcurrency:
|
||||
case RedeemTypeConcurrency:
|
||||
// 增加用户并发数
|
||||
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
|
||||
return nil, fmt.Errorf("update user concurrency: %w", err)
|
||||
}
|
||||
|
||||
case model.RedeemTypeSubscription:
|
||||
case RedeemTypeSubscription:
|
||||
validityDays := redeemCode.ValidityDays
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
@@ -320,7 +318,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取兑换码
|
||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) {
|
||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||||
@@ -329,7 +327,7 @@ func (s *RedeemService) GetByID(ctx context.Context, id int64) (*model.RedeemCod
|
||||
}
|
||||
|
||||
// GetByCode 根据Code获取兑换码
|
||||
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) {
|
||||
func (s *RedeemService) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
|
||||
redeemCode, err := s.redeemRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get redeem code: %w", err)
|
||||
@@ -338,7 +336,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
|
||||
}
|
||||
|
||||
// List 获取兑换码列表(管理员功能)
|
||||
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
codes, pagination, err := s.redeemRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
|
||||
@@ -383,7 +381,7 @@ func (s *RedeemService) GetStats(ctx context.Context) (map[string]any, error) {
|
||||
}
|
||||
|
||||
// GetUserHistory 获取用户的兑换历史
|
||||
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) {
|
||||
func (s *RedeemService) GetUserHistory(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
|
||||
codes, err := s.redeemRepo.ListByUser(ctx, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user redeem history: %w", err)
|
||||
|
||||
10
backend/internal/service/setting.go
Normal file
10
backend/internal/service/setting.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type Setting struct {
|
||||
ID int64
|
||||
Key string
|
||||
Value string
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -19,7 +18,7 @@ var (
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
Get(ctx context.Context, key string) (*model.Setting, error)
|
||||
Get(ctx context.Context, key string) (*Setting, error)
|
||||
GetValue(ctx context.Context, key string) (string, error)
|
||||
Set(ctx context.Context, key, value string) error
|
||||
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
|
||||
@@ -43,7 +42,7 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
|
||||
}
|
||||
|
||||
// GetAllSettings 获取所有系统设置
|
||||
func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSettings, error) {
|
||||
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
|
||||
settings, err := s.settingRepo.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get all settings: %w", err)
|
||||
@@ -53,18 +52,18 @@ func (s *SettingService) GetAllSettings(ctx context.Context) (*model.SystemSetti
|
||||
}
|
||||
|
||||
// GetPublicSettings 获取公开设置(无需登录)
|
||||
func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSettings, error) {
|
||||
func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) {
|
||||
keys := []string{
|
||||
model.SettingKeyRegistrationEnabled,
|
||||
model.SettingKeyEmailVerifyEnabled,
|
||||
model.SettingKeyTurnstileEnabled,
|
||||
model.SettingKeyTurnstileSiteKey,
|
||||
model.SettingKeySiteName,
|
||||
model.SettingKeySiteLogo,
|
||||
model.SettingKeySiteSubtitle,
|
||||
model.SettingKeyApiBaseUrl,
|
||||
model.SettingKeyContactInfo,
|
||||
model.SettingKeyDocUrl,
|
||||
SettingKeyRegistrationEnabled,
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeySiteName,
|
||||
SettingKeySiteLogo,
|
||||
SettingKeySiteSubtitle,
|
||||
SettingKeyApiBaseUrl,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocUrl,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -72,64 +71,64 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*model.PublicSe
|
||||
return nil, fmt.Errorf("get public settings: %w", err)
|
||||
}
|
||||
|
||||
return &model.PublicSettings{
|
||||
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
|
||||
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[model.SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[model.SettingKeyContactInfo],
|
||||
DocUrl: settings[model.SettingKeyDocUrl],
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateSettings 更新系统设置
|
||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *model.SystemSettings) error {
|
||||
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
|
||||
updates := make(map[string]string)
|
||||
|
||||
// 注册设置
|
||||
updates[model.SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[model.SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[model.SettingKeySmtpHost] = settings.SmtpHost
|
||||
updates[model.SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
|
||||
updates[model.SettingKeySmtpUsername] = settings.SmtpUsername
|
||||
updates[SettingKeySmtpHost] = settings.SmtpHost
|
||||
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
|
||||
updates[SettingKeySmtpUsername] = settings.SmtpUsername
|
||||
if settings.SmtpPassword != "" {
|
||||
updates[model.SettingKeySmtpPassword] = settings.SmtpPassword
|
||||
updates[SettingKeySmtpPassword] = settings.SmtpPassword
|
||||
}
|
||||
updates[model.SettingKeySmtpFrom] = settings.SmtpFrom
|
||||
updates[model.SettingKeySmtpFromName] = settings.SmtpFromName
|
||||
updates[model.SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
|
||||
updates[SettingKeySmtpFrom] = settings.SmtpFrom
|
||||
updates[SettingKeySmtpFromName] = settings.SmtpFromName
|
||||
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
|
||||
|
||||
// Cloudflare Turnstile 设置(只有非空才更新密钥)
|
||||
updates[model.SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
|
||||
updates[model.SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
|
||||
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
|
||||
updates[SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey
|
||||
if settings.TurnstileSecretKey != "" {
|
||||
updates[model.SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
|
||||
// OEM设置
|
||||
updates[model.SettingKeySiteName] = settings.SiteName
|
||||
updates[model.SettingKeySiteLogo] = settings.SiteLogo
|
||||
updates[model.SettingKeySiteSubtitle] = settings.SiteSubtitle
|
||||
updates[model.SettingKeyApiBaseUrl] = settings.ApiBaseUrl
|
||||
updates[model.SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[model.SettingKeyDocUrl] = settings.DocUrl
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
|
||||
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocUrl] = settings.DocUrl
|
||||
|
||||
// 默认配置
|
||||
updates[model.SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[model.SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, updates)
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
if err != nil {
|
||||
// 默认开放注册
|
||||
return true
|
||||
@@ -139,7 +138,7 @@ func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
|
||||
|
||||
// IsEmailVerifyEnabled 检查是否开启邮件验证
|
||||
func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyEmailVerifyEnabled)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEmailVerifyEnabled)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -148,7 +147,7 @@ func (s *SettingService) IsEmailVerifyEnabled(ctx context.Context) bool {
|
||||
|
||||
// GetSiteName 获取网站名称
|
||||
func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeySiteName)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
if err != nil || value == "" {
|
||||
return "Sub2API"
|
||||
}
|
||||
@@ -157,7 +156,7 @@ func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
|
||||
// GetDefaultConcurrency 获取默认并发量
|
||||
func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultConcurrency)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultConcurrency)
|
||||
if err != nil {
|
||||
return s.cfg.Default.UserConcurrency
|
||||
}
|
||||
@@ -169,7 +168,7 @@ func (s *SettingService) GetDefaultConcurrency(ctx context.Context) int {
|
||||
|
||||
// GetDefaultBalance 获取默认余额
|
||||
func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyDefaultBalance)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultBalance)
|
||||
if err != nil {
|
||||
return s.cfg.Default.UserBalance
|
||||
}
|
||||
@@ -182,7 +181,7 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
|
||||
// InitializeDefaultSettings 初始化默认设置
|
||||
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// 检查是否已有设置
|
||||
_, err := s.settingRepo.GetValue(ctx, model.SettingKeyRegistrationEnabled)
|
||||
_, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
|
||||
if err == nil {
|
||||
// 已有设置,不需要初始化
|
||||
return nil
|
||||
@@ -193,62 +192,62 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
|
||||
// 初始化默认设置
|
||||
defaults := map[string]string{
|
||||
model.SettingKeyRegistrationEnabled: "true",
|
||||
model.SettingKeyEmailVerifyEnabled: "false",
|
||||
model.SettingKeySiteName: "Sub2API",
|
||||
model.SettingKeySiteLogo: "",
|
||||
model.SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
model.SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
model.SettingKeySmtpPort: "587",
|
||||
model.SettingKeySmtpUseTLS: "false",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySmtpPort: "587",
|
||||
SettingKeySmtpUseTLS: "false",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
}
|
||||
|
||||
// parseSettings 解析设置到结构体
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *model.SystemSettings {
|
||||
result := &model.SystemSettings{
|
||||
RegistrationEnabled: settings[model.SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[model.SettingKeyEmailVerifyEnabled] == "true",
|
||||
SmtpHost: settings[model.SettingKeySmtpHost],
|
||||
SmtpUsername: settings[model.SettingKeySmtpUsername],
|
||||
SmtpFrom: settings[model.SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[model.SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[model.SettingKeySmtpUseTLS] == "true",
|
||||
TurnstileEnabled: settings[model.SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[model.SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, model.SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[model.SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, model.SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[model.SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[model.SettingKeyContactInfo],
|
||||
DocUrl: settings[model.SettingKeyDocUrl],
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
SmtpHost: settings[SettingKeySmtpHost],
|
||||
SmtpUsername: settings[SettingKeySmtpUsername],
|
||||
SmtpFrom: settings[SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
if port, err := strconv.Atoi(settings[model.SettingKeySmtpPort]); err == nil {
|
||||
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
|
||||
result.SmtpPort = port
|
||||
} else {
|
||||
result.SmtpPort = 587
|
||||
}
|
||||
|
||||
if concurrency, err := strconv.Atoi(settings[model.SettingKeyDefaultConcurrency]); err == nil {
|
||||
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
|
||||
result.DefaultConcurrency = concurrency
|
||||
} else {
|
||||
result.DefaultConcurrency = s.cfg.Default.UserConcurrency
|
||||
}
|
||||
|
||||
// 解析浮点数类型
|
||||
if balance, err := strconv.ParseFloat(settings[model.SettingKeyDefaultBalance], 64); err == nil {
|
||||
if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
|
||||
result.DefaultBalance = balance
|
||||
} else {
|
||||
result.DefaultBalance = s.cfg.Default.UserBalance
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[model.SettingKeySmtpPassword]
|
||||
result.TurnstileSecretKey = settings[model.SettingKeyTurnstileSecretKey]
|
||||
result.SmtpPassword = settings[SettingKeySmtpPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -263,7 +262,7 @@ func (s *SettingService) getStringOrDefault(settings map[string]string, key, def
|
||||
|
||||
// IsTurnstileEnabled 检查是否启用 Turnstile 验证
|
||||
func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileEnabled)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileEnabled)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -272,7 +271,7 @@ func (s *SettingService) IsTurnstileEnabled(ctx context.Context) bool {
|
||||
|
||||
// GetTurnstileSecretKey 获取 Turnstile Secret Key
|
||||
func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, model.SettingKeyTurnstileSecretKey)
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyTurnstileSecretKey)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -287,10 +286,10 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
key := model.AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
|
||||
// 存储到 settings 表
|
||||
if err := s.settingRepo.Set(ctx, model.SettingKeyAdminApiKey, key); err != nil {
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
|
||||
return "", fmt.Errorf("save admin api key: %w", err)
|
||||
}
|
||||
|
||||
@@ -300,7 +299,7 @@ func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error
|
||||
// GetAdminApiKeyStatus 获取管理员 API Key 状态
|
||||
// 返回脱敏的 key、是否存在、错误
|
||||
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", false, nil
|
||||
@@ -324,7 +323,7 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
|
||||
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
|
||||
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, model.SettingKeyAdminApiKey)
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", nil // 未配置,返回空字符串
|
||||
@@ -336,5 +335,5 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, model.SettingKeyAdminApiKey)
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
|
||||
}
|
||||
|
||||
42
backend/internal/service/settings_view.go
Normal file
42
backend/internal/service/settings_view.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package service
|
||||
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
|
||||
SmtpHost string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpFrom string
|
||||
SmtpFromName string
|
||||
SmtpUseTLS bool
|
||||
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
TurnstileSecretKey string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
ApiBaseUrl string
|
||||
ContactInfo string
|
||||
DocUrl string
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
ApiBaseUrl string
|
||||
ContactInfo string
|
||||
DocUrl string
|
||||
Version string
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -48,7 +47,7 @@ type AssignSubscriptionInput struct {
|
||||
}
|
||||
|
||||
// AssignSubscription 分配订阅给用户(不允许重复分配)
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
@@ -91,7 +90,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
// - 已过期:从当前时间开始计算新的过期时间,并激活订阅
|
||||
//
|
||||
// 如果没有订阅:创建新订阅
|
||||
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
|
||||
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
@@ -132,8 +131,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != model.SubscriptionStatusActive {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
if existingSub.Status != SubscriptionStatusActive {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, SubscriptionStatusActive); err != nil {
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -185,19 +184,19 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// createSubscription 创建新订阅(内部方法)
|
||||
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
|
||||
validityDays := input.ValidityDays
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
sub := &model.UserSubscription{
|
||||
sub := &UserSubscription{
|
||||
UserID: input.UserID,
|
||||
GroupID: input.GroupID,
|
||||
StartsAt: now,
|
||||
ExpiresAt: now.AddDate(0, 0, validityDays),
|
||||
Status: model.SubscriptionStatusActive,
|
||||
Status: SubscriptionStatusActive,
|
||||
AssignedAt: now,
|
||||
Notes: input.Notes,
|
||||
CreatedAt: now,
|
||||
@@ -229,14 +228,14 @@ type BulkAssignSubscriptionInput struct {
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int
|
||||
FailedCount int
|
||||
Subscriptions []model.UserSubscription
|
||||
Subscriptions []UserSubscription
|
||||
Errors []string
|
||||
}
|
||||
|
||||
// BulkAssignSubscription 批量分配订阅
|
||||
func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) {
|
||||
result := &BulkAssignResult{
|
||||
Subscriptions: make([]model.UserSubscription, 0),
|
||||
Subscriptions: make([]UserSubscription, 0),
|
||||
Errors: make([]string, 0),
|
||||
}
|
||||
|
||||
@@ -286,7 +285,7 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
@@ -299,8 +298,8 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// 如果订阅已过期,恢复为active状态
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
if sub.Status == SubscriptionStatusExpired {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, SubscriptionStatusActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -319,12 +318,12 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*UserSubscription, error) {
|
||||
return s.userSubRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
@@ -333,7 +332,7 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
}
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
|
||||
subs, err := s.userSubRepo.ListByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -343,7 +342,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
|
||||
}
|
||||
|
||||
// ListActiveUserSubscriptions 获取用户的所有有效订阅
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]UserSubscription, error) {
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -353,7 +352,7 @@ func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, u
|
||||
}
|
||||
|
||||
// ListGroupSubscriptions 获取分组的所有订阅
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
subs, pag, err := s.userSubRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
@@ -364,7 +363,7 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
if err != nil {
|
||||
@@ -376,7 +375,7 @@ func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, user
|
||||
|
||||
// normalizeExpiredWindows 将已过期窗口的数据清零(仅影响返回数据,不影响数据库)
|
||||
// 这确保前端显示正确的当前窗口状态,而不是过期窗口的历史数据
|
||||
func normalizeExpiredWindows(subs []model.UserSubscription) {
|
||||
func normalizeExpiredWindows(subs []UserSubscription) {
|
||||
for i := range subs {
|
||||
sub := &subs[i]
|
||||
// 日窗口过期:清零展示数据
|
||||
@@ -403,7 +402,7 @@ func startOfDay(t time.Time) time.Time {
|
||||
}
|
||||
|
||||
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
|
||||
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *model.UserSubscription) error {
|
||||
func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *UserSubscription) error {
|
||||
if sub.IsWindowActivated() {
|
||||
return nil
|
||||
}
|
||||
@@ -414,7 +413,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
|
||||
}
|
||||
|
||||
// CheckAndResetWindows 检查并重置过期的窗口
|
||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *model.UserSubscription) error {
|
||||
func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *UserSubscription) error {
|
||||
// 使用当天零点作为新窗口起始时间
|
||||
windowStart := startOfDay(time.Now())
|
||||
needsInvalidateCache := false
|
||||
@@ -458,7 +457,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
|
||||
}
|
||||
|
||||
// CheckUsageLimits 检查使用限额(返回错误如果超限)
|
||||
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.UserSubscription, group *model.Group, additionalCost float64) error {
|
||||
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
|
||||
if !sub.CheckDailyLimit(group, additionalCost) {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
@@ -620,16 +619,16 @@ func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (i
|
||||
}
|
||||
|
||||
// ValidateSubscription 验证订阅是否有效
|
||||
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *model.UserSubscription) error {
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
|
||||
if sub.Status == SubscriptionStatusExpired {
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
if sub.Status == model.SubscriptionStatusSuspended {
|
||||
if sub.Status == SubscriptionStatusSuspended {
|
||||
return ErrSubscriptionSuspended
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
// 更新状态
|
||||
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, SubscriptionStatusExpired)
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// TokenRefreshService OAuth token自动刷新服务
|
||||
@@ -144,19 +143,19 @@ func (s *TokenRefreshService) processRefresh() {
|
||||
|
||||
// listActiveAccounts 获取所有active状态的账号
|
||||
// 使用ListActive确保刷新所有活跃账号的token(包括临时禁用的)
|
||||
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]model.Account, error) {
|
||||
func (s *TokenRefreshService) listActiveAccounts(ctx context.Context) ([]Account, error) {
|
||||
return s.accountRepo.ListActive(ctx)
|
||||
}
|
||||
|
||||
// refreshWithRetry 带重试的刷新
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *model.Account, refresher TokenRefresher) error {
|
||||
func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Account, refresher TokenRefresher) error {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
account.Credentials = model.JSONB(newCredentials)
|
||||
account.Credentials = newCredentials
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -4,22 +4,20 @@ import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// TokenRefresher 定义平台特定的token刷新策略接口
|
||||
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
|
||||
type TokenRefresher interface {
|
||||
// CanRefresh 检查此刷新器是否能处理指定账号
|
||||
CanRefresh(account *model.Account) bool
|
||||
CanRefresh(account *Account) bool
|
||||
|
||||
// NeedsRefresh 检查账号的token是否需要刷新
|
||||
NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool
|
||||
NeedsRefresh(account *Account, refreshWindow time.Duration) bool
|
||||
|
||||
// Refresh 执行token刷新,返回更新后的credentials
|
||||
// 注意:返回的map应该保留原有credentials中的所有字段,只更新token相关字段
|
||||
Refresh(ctx context.Context, account *model.Account) (map[string]any, error)
|
||||
Refresh(ctx context.Context, account *Account) (map[string]any, error)
|
||||
}
|
||||
|
||||
// ClaudeTokenRefresher 处理Anthropic/Claude OAuth token刷新
|
||||
@@ -37,14 +35,14 @@ func NewClaudeTokenRefresher(oauthService *OAuthService) *ClaudeTokenRefresher {
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 anthropic 平台的 oauth 类型账号
|
||||
// setup-token 虽然也是OAuth,但有效期1年,不需要频繁刷新
|
||||
func (r *ClaudeTokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformAnthropic &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformAnthropic &&
|
||||
account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return false
|
||||
@@ -61,7 +59,7 @@ func (r *ClaudeTokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
func (r *ClaudeTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -103,14 +101,14 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService) *OpenAIToke
|
||||
|
||||
// CanRefresh 检查是否能处理此账号
|
||||
// 只处理 openai 平台的 oauth 类型账号
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *model.Account) bool {
|
||||
return account.Platform == model.PlatformOpenAI &&
|
||||
account.Type == model.AccountTypeOAuth
|
||||
func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformOpenAI &&
|
||||
account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindow time.Duration) bool {
|
||||
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
expiresAt := account.GetOpenAITokenExpiresAt()
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
@@ -121,7 +119,7 @@ func (r *OpenAITokenRefresher) NeedsRefresh(account *model.Account, refreshWindo
|
||||
|
||||
// Refresh 执行token刷新
|
||||
// 保留原有credentials中的所有字段,只更新token相关字段
|
||||
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *model.Account) (map[string]any, error) {
|
||||
func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
53
backend/internal/service/usage_log.go
Normal file
53
backend/internal/service/usage_log.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
BillingTypeBalance int8 = 0 // 钱包余额
|
||||
BillingTypeSubscription int8 = 1 // 订阅套餐
|
||||
)
|
||||
|
||||
type UsageLog struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
|
||||
GroupID *int64
|
||||
SubscriptionID *int64
|
||||
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
CacheCreationCost float64
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
ActualCost float64
|
||||
RateMultiplier float64
|
||||
|
||||
BillingType int8
|
||||
Stream bool
|
||||
DurationMs *int
|
||||
FirstTokenMs *int
|
||||
|
||||
CreatedAt time.Time
|
||||
|
||||
User *User
|
||||
ApiKey *ApiKey
|
||||
Account *Account
|
||||
Group *Group
|
||||
Subscription *UserSubscription
|
||||
}
|
||||
|
||||
func (u *UsageLog) TotalTokens() int {
|
||||
return u.InputTokens + u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
@@ -66,7 +65,7 @@ func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *Usa
|
||||
}
|
||||
|
||||
// Create 创建使用日志
|
||||
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*model.UsageLog, error) {
|
||||
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
|
||||
// 验证用户存在
|
||||
_, err := s.userRepo.GetByID(ctx, req.UserID)
|
||||
if err != nil {
|
||||
@@ -74,7 +73,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
||||
}
|
||||
|
||||
// 创建使用日志
|
||||
usageLog := &model.UsageLog{
|
||||
usageLog := &UsageLog{
|
||||
UserID: req.UserID,
|
||||
ApiKeyID: req.ApiKeyID,
|
||||
AccountID: req.AccountID,
|
||||
@@ -112,7 +111,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取使用日志
|
||||
func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, error) {
|
||||
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
|
||||
log, err := s.usageRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage log: %w", err)
|
||||
@@ -121,7 +120,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
|
||||
}
|
||||
|
||||
// ListByUser 获取用户的使用日志列表
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -130,7 +129,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
|
||||
}
|
||||
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -139,7 +138,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
|
||||
}
|
||||
|
||||
// ListByAccount 获取账号的使用日志列表
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -243,7 +242,7 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
}
|
||||
|
||||
// calculateStats 计算统计数据
|
||||
func (s *UsageService) calculateStats(logs []model.UsageLog) *UsageStats {
|
||||
func (s *UsageService) calculateStats(logs []UsageLog) *UsageStats {
|
||||
stats := &UsageStats{}
|
||||
|
||||
for _, log := range logs {
|
||||
@@ -313,7 +312,7 @@ func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs [
|
||||
}
|
||||
|
||||
// ListWithFilters lists usage logs with admin filters.
|
||||
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, result, err := s.usageRepo.ListWithFilters(ctx, params, filters)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs with filters: %w", err)
|
||||
|
||||
63
backend/internal/service/user.go
Normal file
63
backend/internal/service/user.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Email string
|
||||
Username string
|
||||
Wechat string
|
||||
Notes string
|
||||
PasswordHash string
|
||||
Role string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Status string
|
||||
AllowedGroups []int64
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
ApiKeys []ApiKey
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
func (u *User) IsAdmin() bool {
|
||||
return u.Role == RoleAdmin
|
||||
}
|
||||
|
||||
func (u *User) IsActive() bool {
|
||||
return u.Status == StatusActive
|
||||
}
|
||||
|
||||
// CanBindGroup checks whether a user can bind to a given group.
|
||||
// For standard groups:
|
||||
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
|
||||
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
|
||||
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return !isExclusive
|
||||
}
|
||||
|
||||
func (u *User) SetPassword(password string) error {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.PasswordHash = string(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *User) CheckPassword(password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil
|
||||
}
|
||||
@@ -5,9 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -17,15 +15,15 @@ var (
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*model.User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*model.User, error)
|
||||
Update(ctx context.Context, user *model.User) error
|
||||
Create(ctx context.Context, user *User) error
|
||||
GetByID(ctx context.Context, id int64) (*User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*User, error)
|
||||
GetFirstAdmin(ctx context.Context) (*User, error)
|
||||
Update(ctx context.Context, user *User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error)
|
||||
|
||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||
@@ -61,7 +59,7 @@ func NewUserService(userRepo UserRepository) *UserService {
|
||||
}
|
||||
|
||||
// GetFirstAdmin 获取首个管理员用户(用于 Admin API Key 认证)
|
||||
func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
func (s *UserService) GetFirstAdmin(ctx context.Context) (*User, error) {
|
||||
admin, err := s.userRepo.GetFirstAdmin(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get first admin: %w", err)
|
||||
@@ -70,7 +68,7 @@ func (s *UserService) GetFirstAdmin(ctx context.Context) (*model.User, error) {
|
||||
}
|
||||
|
||||
// GetProfile 获取用户资料
|
||||
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User, error) {
|
||||
func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
@@ -79,7 +77,7 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*model.User
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户资料
|
||||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*model.User, error) {
|
||||
func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req UpdateProfileRequest) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
@@ -125,18 +123,14 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
|
||||
}
|
||||
|
||||
// 验证当前密码
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
|
||||
if !user.CheckPassword(req.CurrentPassword) {
|
||||
return ErrPasswordIncorrect
|
||||
}
|
||||
|
||||
// 生成新密码哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
if err := user.SetPassword(req.NewPassword); err != nil {
|
||||
return fmt.Errorf("set password: %w", err)
|
||||
}
|
||||
|
||||
user.PasswordHash = string(hashedPassword)
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
@@ -145,7 +139,7 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, req Chan
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户(管理员功能)
|
||||
func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
@@ -154,7 +148,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
}
|
||||
|
||||
// List 获取用户列表(管理员功能)
|
||||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
users, pagination, err := s.userRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
|
||||
124
backend/internal/service/user_subscription.go
Normal file
124
backend/internal/service/user_subscription.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type UserSubscription struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
GroupID int64
|
||||
|
||||
StartsAt time.Time
|
||||
ExpiresAt time.Time
|
||||
Status string
|
||||
|
||||
DailyWindowStart *time.Time
|
||||
WeeklyWindowStart *time.Time
|
||||
MonthlyWindowStart *time.Time
|
||||
|
||||
DailyUsageUSD float64
|
||||
WeeklyUsageUSD float64
|
||||
MonthlyUsageUSD float64
|
||||
|
||||
AssignedBy *int64
|
||||
AssignedAt time.Time
|
||||
Notes string
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
User *User
|
||||
Group *Group
|
||||
AssignedByUser *User
|
||||
}
|
||||
|
||||
func (s *UserSubscription) IsActive() bool {
|
||||
return s.Status == SubscriptionStatusActive && time.Now().Before(s.ExpiresAt)
|
||||
}
|
||||
|
||||
func (s *UserSubscription) IsExpired() bool {
|
||||
return time.Now().After(s.ExpiresAt)
|
||||
}
|
||||
|
||||
func (s *UserSubscription) DaysRemaining() int {
|
||||
if s.IsExpired() {
|
||||
return 0
|
||||
}
|
||||
return int(time.Until(s.ExpiresAt).Hours() / 24)
|
||||
}
|
||||
|
||||
func (s *UserSubscription) IsWindowActivated() bool {
|
||||
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
|
||||
}
|
||||
|
||||
func (s *UserSubscription) NeedsDailyReset() bool {
|
||||
if s.DailyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.DailyWindowStart) >= 24*time.Hour
|
||||
}
|
||||
|
||||
func (s *UserSubscription) NeedsWeeklyReset() bool {
|
||||
if s.WeeklyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.WeeklyWindowStart) >= 7*24*time.Hour
|
||||
}
|
||||
|
||||
func (s *UserSubscription) NeedsMonthlyReset() bool {
|
||||
if s.MonthlyWindowStart == nil {
|
||||
return false
|
||||
}
|
||||
return time.Since(*s.MonthlyWindowStart) >= 30*24*time.Hour
|
||||
}
|
||||
|
||||
func (s *UserSubscription) DailyResetTime() *time.Time {
|
||||
if s.DailyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.DailyWindowStart.Add(24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
func (s *UserSubscription) WeeklyResetTime() *time.Time {
|
||||
if s.WeeklyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.WeeklyWindowStart.Add(7 * 24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
func (s *UserSubscription) MonthlyResetTime() *time.Time {
|
||||
if s.MonthlyWindowStart == nil {
|
||||
return nil
|
||||
}
|
||||
t := s.MonthlyWindowStart.Add(30 * 24 * time.Hour)
|
||||
return &t
|
||||
}
|
||||
|
||||
func (s *UserSubscription) CheckDailyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasDailyLimit() {
|
||||
return true
|
||||
}
|
||||
return s.DailyUsageUSD+additionalCost <= *group.DailyLimitUSD
|
||||
}
|
||||
|
||||
func (s *UserSubscription) CheckWeeklyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasWeeklyLimit() {
|
||||
return true
|
||||
}
|
||||
return s.WeeklyUsageUSD+additionalCost <= *group.WeeklyLimitUSD
|
||||
}
|
||||
|
||||
func (s *UserSubscription) CheckMonthlyLimit(group *Group, additionalCost float64) bool {
|
||||
if !group.HasMonthlyLimit() {
|
||||
return true
|
||||
}
|
||||
return s.MonthlyUsageUSD+additionalCost <= *group.MonthlyLimitUSD
|
||||
}
|
||||
|
||||
func (s *UserSubscription) CheckAllLimits(group *Group, additionalCost float64) (daily, weekly, monthly bool) {
|
||||
daily = s.CheckDailyLimit(group, additionalCost)
|
||||
weekly = s.CheckWeeklyLimit(group, additionalCost)
|
||||
monthly = s.CheckMonthlyLimit(group, additionalCost)
|
||||
return
|
||||
}
|
||||
@@ -4,22 +4,21 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepository interface {
|
||||
Create(ctx context.Context, sub *model.UserSubscription) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
|
||||
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
Update(ctx context.Context, sub *model.UserSubscription) error
|
||||
Create(ctx context.Context, sub *UserSubscription) error
|
||||
GetByID(ctx context.Context, id int64) (*UserSubscription, error)
|
||||
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error)
|
||||
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*UserSubscription, error)
|
||||
Update(ctx context.Context, sub *UserSubscription) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
|
||||
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error)
|
||||
|
||||
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
|
||||
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
|
||||
|
||||
Reference in New Issue
Block a user