feat: rebuild auth identity foundation flow

This commit is contained in:
IanShaw027
2026-04-20 17:39:57 +08:00
parent fbd0a2e3c4
commit e9de839d87
123 changed files with 33599 additions and 772 deletions

View File

@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}

View File

@@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
return s.deleteErr
}
func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
panic("unexpected GetUserAvatar call")
}
func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
panic("unexpected UpsertUserAvatar call")
}
func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
panic("unexpected DeleteUserAvatar call")
}
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}

View File

@@ -0,0 +1,326 @@
package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (
ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
)
const (
defaultPendingAuthTTL = 15 * time.Minute
defaultPendingAuthCompletionTTL = 5 * time.Minute
)
type PendingAuthIdentityKey struct {
ProviderType string
ProviderKey string
ProviderSubject string
}
type CreatePendingAuthSessionInput struct {
SessionToken string
Intent string
Identity PendingAuthIdentityKey
TargetUserID *int64
RedirectTo string
ResolvedEmail string
RegistrationPasswordHash string
BrowserSessionKey string
UpstreamIdentityClaims map[string]any
LocalFlowState map[string]any
ExpiresAt time.Time
}
type IssuePendingAuthCompletionCodeInput struct {
PendingAuthSessionID int64
BrowserSessionKey string
TTL time.Duration
}
type IssuePendingAuthCompletionCodeResult struct {
Code string
ExpiresAt time.Time
}
type PendingIdentityAdoptionDecisionInput struct {
PendingAuthSessionID int64
IdentityID *int64
AdoptDisplayName bool
AdoptAvatar bool
}
type AuthPendingIdentityService struct {
entClient *dbent.Client
}
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
return &AuthPendingIdentityService{entClient: entClient}
}
func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken := strings.TrimSpace(input.SessionToken)
if sessionToken == "" {
var err error
sessionToken, err = randomOpaqueToken(24)
if err != nil {
return nil, err
}
}
expiresAt := input.ExpiresAt.UTC()
if expiresAt.IsZero() {
expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
}
create := s.entClient.PendingAuthSession.Create().
SetSessionToken(sessionToken).
SetIntent(strings.TrimSpace(input.Intent)).
SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
SetExpiresAt(expiresAt)
if input.TargetUserID != nil {
create = create.SetTargetUserID(*input.TargetUserID)
}
return create.Save(ctx)
}
func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
code, err := randomOpaqueToken(24)
if err != nil {
return nil, err
}
ttl := input.TTL
if ttl <= 0 {
ttl = defaultPendingAuthCompletionTTL
}
expiresAt := time.Now().UTC().Add(ttl)
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeHash(hashPendingAuthCode(code)).
SetCompletionCodeExpiresAt(expiresAt)
if strings.TrimSpace(input.BrowserSessionKey) != "" {
update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
}
if _, err := update.Save(ctx); err != nil {
return nil, err
}
return &IssuePendingAuthCompletionCodeResult{
Code: code,
ExpiresAt: expiresAt,
}, nil
}
func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthCodeInvalid
}
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
}
func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
}
func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
session, err := s.getBrowserSession(ctx, sessionToken)
if err != nil {
return nil, err
}
if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
sessionToken = strings.TrimSpace(sessionToken)
if sessionToken == "" {
return nil, ErrPendingAuthSessionNotFound
}
session, err := s.entClient.PendingAuthSession.Query().
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, ErrPendingAuthSessionNotFound
}
return nil, err
}
return session, nil
}
func (s *AuthPendingIdentityService) consumeSession(
ctx context.Context,
session *dbent.PendingAuthSession,
browserSessionKey string,
expiredErr error,
consumedErr error,
) (*dbent.PendingAuthSession, error) {
if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
return nil, err
}
now := time.Now().UTC()
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
SetConsumedAt(now).
SetCompletionCodeHash("").
ClearCompletionCodeExpiresAt().
Save(ctx)
if err != nil {
return nil, err
}
return updated, nil
}
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
if session == nil {
return ErrPendingAuthSessionNotFound
}
now := time.Now().UTC()
if session.ConsumedAt != nil {
return consumedErr
}
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
return expiredErr
}
if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
return expiredErr
}
if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
return ErrPendingAuthBrowserMismatch
}
return nil
}
func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
if s == nil || s.entClient == nil {
return nil, fmt.Errorf("pending auth ent client is not configured")
}
existing, err := s.entClient.IdentityAdoptionDecision.Query().
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
Only(ctx)
if err != nil && !dbent.IsNotFound(err) {
return nil, err
}
if existing == nil {
create := s.entClient.IdentityAdoptionDecision.Create().
SetPendingAuthSessionID(input.PendingAuthSessionID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar).
SetDecidedAt(time.Now().UTC())
if input.IdentityID != nil {
create = create.SetIdentityID(*input.IdentityID)
}
return create.Save(ctx)
}
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
SetAdoptDisplayName(input.AdoptDisplayName).
SetAdoptAvatar(input.AdoptAvatar)
if input.IdentityID != nil {
update = update.SetIdentityID(*input.IdentityID)
}
return update.Save(ctx)
}
func copyPendingMap(in map[string]any) map[string]any {
if len(in) == 0 {
return map[string]any{}
}
out := make(map[string]any, len(in))
for k, v := range in {
out[k] = v
}
return out
}
func randomOpaqueToken(byteLen int) (string, error) {
if byteLen <= 0 {
byteLen = 16
}
buf := make([]byte, byteLen)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func hashPendingAuthCode(code string) string {
sum := sha256.Sum256([]byte(code))
return hex.EncodeToString(sum[:])
}

View File

@@ -0,0 +1,224 @@
//go:build unit
package service
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return NewAuthPendingIdentityService(client), client
}
func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
targetUser, err := client.User.Create().
SetEmail("pending-target@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-123",
},
TargetUserID: &targetUser.ID,
RedirectTo: "/profile",
ResolvedEmail: "user@example.com",
BrowserSessionKey: "browser-1",
UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
LocalFlowState: map[string]any{"step": "email_required"},
})
require.NoError(t, err)
require.NotEmpty(t, session.SessionToken)
require.Equal(t, "bind_current_user", session.Intent)
require.Equal(t, "wechat", session.ProviderType)
require.NotNil(t, session.TargetUserID)
require.Equal(t, targetUser.ID, *session.TargetUserID)
require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
require.Equal(t, "email_required", session.LocalFlowState["step"])
}
func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo-main",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expected",
UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
LocalFlowState: map[string]any{"step": "pending"},
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expected",
})
require.NoError(t, err)
require.NotEmpty(t, issued.Code)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
require.Empty(t, consumed.CompletionCodeHash)
require.Nil(t, consumed.CompletionCodeExpiresAt)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
}
func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example",
ProviderSubject: "subject-1",
},
BrowserSessionKey: "browser-expired",
})
require.NoError(t, err)
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
PendingAuthSessionID: session.ID,
BrowserSessionKey: "browser-expired",
TTL: time.Second,
})
require.NoError(t, err)
_, err = client.PendingAuthSession.UpdateOneID(session.ID).
SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
Save(ctx)
require.NoError(t, err)
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
}
func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
svc, client := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
user, err := client.User.Create().
SetEmail("adoption@example.com").
SetPasswordHash("hash").
SetRole(RoleUser).
SetStatus(StatusActive).
Save(ctx)
require.NoError(t, err)
identity, err := client.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("wechat").
SetProviderKey("wechat-open").
SetProviderSubject("union-adoption").
SetMetadata(map[string]any{}).
Save(ctx)
require.NoError(t, err)
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "bind_current_user",
Identity: PendingAuthIdentityKey{
ProviderType: "wechat",
ProviderKey: "wechat-open",
ProviderSubject: "union-adoption",
},
})
require.NoError(t, err)
first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
AdoptDisplayName: true,
AdoptAvatar: false,
})
require.NoError(t, err)
require.True(t, first.AdoptDisplayName)
require.False(t, first.AdoptAvatar)
require.Nil(t, first.IdentityID)
second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
PendingAuthSessionID: session.ID,
IdentityID: &identity.ID,
AdoptDisplayName: true,
AdoptAvatar: true,
})
require.NoError(t, err)
require.Equal(t, first.ID, second.ID)
require.NotNil(t, second.IdentityID)
require.Equal(t, identity.ID, *second.IdentityID)
require.True(t, second.AdoptAvatar)
}
func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
svc, _ := newAuthPendingIdentityServiceTestClient(t)
ctx := context.Background()
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
Intent: "login",
Identity: PendingAuthIdentityKey{
ProviderType: "linuxdo",
ProviderKey: "linuxdo",
ProviderSubject: "subject-session-token",
},
BrowserSessionKey: "browser-session",
LocalFlowState: map[string]any{
"completion_response": map[string]any{
"access_token": "token",
},
},
})
require.NoError(t, err)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.NoError(t, err)
require.NotNil(t, consumed.ConsumedAt)
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
}

View File

@@ -13,6 +13,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -106,6 +107,13 @@ func NewAuthService(
}
}
func (s *AuthService) EntClient() *dbent.Client {
if s == nil {
return nil
}
return s.entClient
}
// Register 用户注册返回token和用户
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
return s.RegisterWithVerification(ctx, email, password, "", "", "")
@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
s.postAuthUserBootstrap(ctx, user, "email", true)
s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
if !user.IsActive() {
return "", nil, ErrUserNotActive
}
s.touchUserLogin(ctx, user.ID)
// 生成JWT token
token, err := s.GenerateToken(user)
@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
s.touchUserLogin(ctx, user.ID)
token, err := s.GenerateToken(user)
if err != nil {
@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, ErrServiceUnavailable
}
user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
s.assignDefaultSubscriptions(ctx, user.ID)
if invitationRedeemCode != nil {
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
}
}
s.touchUserLogin(ctx, user.ID)
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
const pendingOAuthTokenTTL = 10 * time.Minute
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
const pendingOAuthPurpose = "pending_oauth_registration"
type pendingOAuthClaims struct {
Email string `json:"email"`
Username string `json:"username"`
Purpose string `json:"purpose"`
jwt.RegisteredClaims
}
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
// while waiting for the user to supply an invitation code.
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
now := time.Now()
claims := &pendingOAuthClaims{
Email: email,
Username: username,
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.cfg.JWT.Secret))
}
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
// Returns ErrInvalidToken when the token is invalid or expired.
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
if len(tokenStr) > maxTokenLength {
return "", "", ErrInvalidToken
}
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return []byte(s.cfg.JWT.Secret), nil
})
if parseErr != nil {
return "", "", ErrInvalidToken
}
claims, ok := token.Claims.(*pendingOAuthClaims)
if !ok || !token.Valid {
return "", "", ErrInvalidToken
}
if claims.Purpose != pendingOAuthPurpose {
return "", "", ErrInvalidToken
}
return claims.Email, claims.Username, nil
}
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
return
@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
}
}
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
if user == nil || user.ID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
signupSource = "email"
}
s.updateUserSignupSource(ctx, user.ID, signupSource)
if signupSource == "email" {
s.ensureEmailAuthIdentity(ctx, user)
}
if touchLogin {
s.touchUserLogin(ctx, user.ID)
}
}
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
if strings.TrimSpace(signupSource) == "" {
return
}
if err := s.entClient.User.UpdateOneID(userID).
SetSignupSource(signupSource).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
}
}
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
if s == nil || s.entClient == nil || userID <= 0 {
return
}
now := time.Now().UTC()
if err := s.entClient.User.UpdateOneID(userID).
SetLastLoginAt(now).
SetLastActiveAt(now).
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
}
}
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
return
}
email := strings.ToLower(strings.TrimSpace(user.Email))
if email == "" || isReservedEmail(email) {
return
}
if err := s.entClient.AuthIdentity.Create().
SetUserID(user.ID).
SetProviderType("email").
SetProviderKey("email").
SetProviderSubject(email).
SetVerifiedAt(time.Now().UTC()).
SetMetadata(map[string]any{
"source": "auth_service_dual_write",
}).
OnConflictColumns(
authidentity.FieldProviderType,
authidentity.FieldProviderKey,
authidentity.FieldProviderSubject,
).
DoNothing().
Exec(ctx); err != nil {
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
}
}
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
return "oidc"
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
return "wechat"
default:
return "email"
}
}
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
if s.settingService == nil {
return nil
@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token

View File

@@ -0,0 +1,153 @@
//go:build unit
package service_test
import (
"context"
"database/sql"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/authidentity"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
type authIdentitySettingRepoStub struct {
values map[string]string
}
func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
panic("unexpected Get call")
}
func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if v, ok := s.values[key]; ok {
return v, nil
}
return "", service.ErrSettingNotFound
}
func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected Set call")
}
func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
panic("unexpected Delete call")
}
func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
t.Helper()
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
require.NoError(t, err)
t.Cleanup(func() { _ = db.Close() })
_, err = db.Exec("PRAGMA foreign_keys = ON")
require.NoError(t, err)
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
repo := repository.NewUserRepository(client, db)
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-auth-identity-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
values: map[string]string{
service.SettingKeyRegistrationEnabled: "true",
},
}, cfg)
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
return svc, repo, client
}
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
svc, _, client := newAuthServiceWithEnt(t)
ctx := context.Background()
token, user, err := svc.Register(ctx, "user@example.com", "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, user)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.Equal(t, "email", storedUser.SignupSource)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
identity, err := client.AuthIdentity.Query().
Where(
authidentity.ProviderTypeEQ("email"),
authidentity.ProviderKeyEQ("email"),
authidentity.ProviderSubjectEQ("user@example.com"),
).
Only(ctx)
require.NoError(t, err)
require.Equal(t, user.ID, identity.UserID)
require.NotNil(t, identity.VerifiedAt)
}
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
svc, repo, client := newAuthServiceWithEnt(t)
ctx := context.Background()
user := &service.User{
Email: "login@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 1,
Concurrency: 1,
}
require.NoError(t, user.SetPassword("password"))
require.NoError(t, repo.Create(ctx, user))
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
_, err := client.User.UpdateOneID(user.ID).
SetLastLoginAt(old).
SetLastActiveAt(old).
Save(ctx)
require.NoError(t, err)
token, gotUser, err := svc.Login(ctx, user.Email, "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, gotUser)
storedUser, err := client.User.Get(ctx, user.ID)
require.NoError(t, err)
require.NotNil(t, storedUser.LastLoginAt)
require.NotNil(t, storedUser.LastActiveAt)
require.True(t, storedUser.LastLoginAt.After(old))
require.True(t, storedUser.LastActiveAt.After(old))
}

View File

@@ -1,146 +0,0 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access tokenJWTClaims无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT模拟旧 token应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}

View File

@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀RFC 保留域名)。
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀RFC 保留域名)。
const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
// Setting keys
const (
// 注册设置
@@ -153,6 +156,29 @@ const (
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表JSON
// 第三方认证来源默认授予配置
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成

View File

@@ -13,14 +13,30 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/singleflight"
)
const (
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
openAIAccountScheduleLayerSessionSticky = "session_hash"
openAIAccountScheduleLayerLoadBalance = "load_balance"
openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
)
const (
openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
)
type cachedOpenAIAdvancedSchedulerSetting struct {
enabled bool
expiresAt int64
}
var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
var openAIAdvancedSchedulerSettingSF singleflight.Group
type OpenAIAccountScheduleRequest struct {
GroupID *int64
SessionHash string
@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
return snapshot
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
return nil
}
return s.rateLimitService.settingService.settingRepo
}
func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled
}
}
result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.enabled, nil
}
}
enabled := false
if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
defer cancel()
value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
if err == nil {
enabled = strings.EqualFold(strings.TrimSpace(value), "true")
}
}
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
enabled: enabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
return enabled, nil
})
enabled, _ := result.(bool)
return enabled
}
func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
if s == nil {
return nil
}
if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
return nil
}
s.openaiSchedulerOnce.Do(func() {
if s.openaiAccountStats == nil {
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
return s.openaiScheduler
}
func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
openAIAdvancedSchedulerSettingCache = atomic.Value{}
openAIAdvancedSchedulerSettingSF = singleflight.Group{}
}
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
ctx context.Context,
groupID *int64,
@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
requiredTransport OpenAIUpstreamTransport,
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
decision := OpenAIAccountScheduleDecision{}
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(ctx)
if scheduler == nil {
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
decision.Layer = openAIAccountScheduleLayerLoadBalance
@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
}
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return
}
@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
}
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
scheduler := s.getOpenAIAccountScheduler()
scheduler := s.getOpenAIAccountScheduler(context.Background())
if scheduler == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
}

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"math"
"sync"
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
accountsByID map[int64]*Account
}
type schedulerTestOpenAIAccountRepo struct {
AccountRepository
accounts []Account
}
func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
for i := range r.accounts {
if r.accounts[i].ID == id {
return &r.accounts[i], nil
}
}
return nil, errors.New("account not found")
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
var result []Account
for _, acc := range r.accounts {
if acc.Platform == platform {
result = append(result, acc)
}
}
return result, nil
}
func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
return r.ListSchedulableByPlatform(ctx, platform)
}
type schedulerTestConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
loadMap map[int64]*AccountLoadInfo
acquireResults map[int64]bool
waitCounts map[int64]int
skipDefaultLoad bool
}
func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if c.acquireResults != nil {
if result, ok := c.acquireResults[accountID]; ok {
return result, nil
}
}
return true, nil
}
func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
return nil
}
func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if c.loadBatchErr != nil {
return nil, c.loadBatchErr
}
out := make(map[int64]*AccountLoadInfo, len(accounts))
if c.skipDefaultLoad && c.loadMap != nil {
for _, acc := range accounts {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
}
}
return out, nil
}
for _, acc := range accounts {
if c.loadMap != nil {
if load, ok := c.loadMap[acc.ID]; ok {
out[acc.ID] = load
continue
}
}
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
}
return out, nil
}
func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
return count, nil
}
}
return 0, nil
}
type schedulerTestGatewayCache struct {
sessionBindings map[string]int64
deletedSessions map[string]int
}
func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
if id, ok := c.sessionBindings[sessionHash]; ok {
return id, nil
}
return 0, errors.New("not found")
}
func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
if c.sessionBindings == nil {
c.sessionBindings = make(map[string]int64)
}
c.sessionBindings[sessionHash] = accountID
return nil
}
func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
return nil
}
func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
if c.sessionBindings == nil {
return nil
}
if c.deletedSessions == nil {
c.deletedSessions = make(map[string]int)
}
c.deletedSessions[sessionHash]++
delete(c.sessionBindings, sessionHash)
return nil
}
func newSchedulerTestOpenAIWSV2Config() *config.Config {
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
return cfg
}
type openAIAdvancedSchedulerSettingRepoStub struct {
values map[string]string
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
value, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &Setting{Key: key, Value: value}, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if s == nil || s.values == nil {
return "", ErrSettingNotFound
}
value, ok := s.values[key]
if !ok {
return "", ErrSettingNotFound
}
return value, nil
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
panic("unexpected call to Set")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
panic("unexpected call to GetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
panic("unexpected call to SetMultiple")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
panic("unexpected call to GetAll")
}
func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
panic("unexpected call to Delete")
}
func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
repo := &openAIAdvancedSchedulerSettingRepoStub{
values: map[string]string{},
}
if enabled != "" {
repo.values[openAIAdvancedSchedulerSettingKey] = enabled
}
return &RateLimitService{
settingService: NewSettingService(repo, &config.Config{}),
}
}
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
if len(s.snapshotAccounts) == 0 {
return nil, false, nil
@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
return &cloned, nil
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10106)
accounts := []Account{
{
ID: 36001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
},
{
ID: 36002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cache := &schedulerTestGatewayCache{}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_disabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(36002), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
require.False(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
ctx := context.Background()
groupID := int64(10107)
accounts := []Account{
{
ID: 37001,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 5,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
},
{
ID: 37002,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Status: StatusActive,
Schedulable: true,
Concurrency: 1,
Priority: 0,
},
}
cfg := &config.Config{}
cfg.Gateway.Scheduling.LoadBatchEnabled = false
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
selection, decision, err := svc.SelectAccountWithScheduler(
ctx,
&groupID,
"resp_enabled_001",
"",
"gpt-5.1",
nil,
OpenAIUpstreamTransportAny,
)
require.NoError(t, err)
require.NotNil(t, selection)
require.NotNil(t, selection.Account)
require.Equal(t, int64(37001), selection.Account.ID)
require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
require.True(t, decision.StickyPreviousHit)
}
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
}
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
ctx := context.Background()
groupID := int64(10101)
@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
svc := &OpenAIGatewayService{
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
require.NoError(t, err)
@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
snapshotCache := &openAISnapshotCacheStub{
snapshotAccounts: []*Account{staleSticky, staleBackup},
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
cache: cache,
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
}
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
cfg: &config.Config{},
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: snapshotService,
}
@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
cache := &stubGatewayCache{}
cache := &schedulerTestGatewayCache{}
cfg := &config.Config{}
cfg.Gateway.OpenAIWS.Enabled = true
cfg.Gateway.OpenAIWS.OAuthEnabled = true
@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: cfg,
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
store := svc.getOpenAIWSStateStore()
@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_abc": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
Priority: 9,
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_sticky_busy": 21001,
},
@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
cfg.Gateway.OpenAIWS.OAuthEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
acquireResults: map[int64]bool{
21001: false, // sticky 账号已满
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
"openai_ws_force_http": true,
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_force_http": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
},
},
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_ws_only": 2201,
},
}
cfg := newOpenAIWSV2TestConfig()
cfg := newSchedulerTestOpenAIWSV2Config()
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: cache,
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
cfg: newOpenAIWSV2TestConfig(),
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: newSchedulerTestOpenAIWSV2Config(),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(
@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
Schedulable: true,
Concurrency: 1,
}
cache := &stubGatewayCache{
cache := &schedulerTestGatewayCache{
sessionBindings: map[string]int64{
"openai:session_hash_metrics": account.ID,
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
cache: cache,
cfg: &config.Config{},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
concurrencyCache := stubConcurrencyCache{
concurrencyCache := schedulerTestConcurrencyCache{
loadMap: map[int64]*AccountLoadInfo{
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
},
}
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
concurrencyService: NewConcurrencyService(concurrencyCache),
}
@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
}
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
resetOpenAIAdvancedSchedulerSettingCacheForTest()
svc := &OpenAIGatewayService{}
ttft := 120
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
require.Equal(t, 7, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
cfg := newOpenAIWSV2TestConfig()
cfg := newSchedulerTestOpenAIWSV2Config()
scheduler.service = &OpenAIGatewayService{cfg: cfg}
account := &Account{
ID: 8801,

View File

@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
svc := &OpenAIGatewayService{
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
cache: &stubGatewayCache{},
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
cache: &schedulerTestGatewayCache{},
cfg: cfg,
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
}
selection, decision, err := svc.SelectAccountWithScheduler(

View File

@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
vals, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get payment config settings: %w", err)
}
cfg := s.parsePaymentConfig(vals)
if s.entClient != nil {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(paymentproviderinstance.EnabledEQ(true)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("list enabled provider instances: %w", err)
}
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances))
} else {
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil)
}
// Load Stripe publishable key from the first enabled Stripe provider instance
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
return cfg, nil
@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
}
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
types := make([]string, 0, len(strings.Split(raw, ",")))
for _, t := range strings.Split(raw, ",") {
t = strings.TrimSpace(t)
if t != "" {
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
types = append(types, t)
}
}
cfg.EnabledTypes = NormalizeVisibleMethods(types)
}
return cfg
}
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
if s.entClient == nil {
return ""
}
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.EnabledEQ(true),
@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
}
return v
}
func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
available := make(map[string]bool, 4)
for _, inst := range instances {
switch inst.ProviderKey {
case payment.TypeAlipay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
available[VisibleMethodSourceOfficialAlipay] = true
}
case payment.TypeWxpay:
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
available[VisibleMethodSourceOfficialWechat] = true
}
case payment.TypeEasyPay:
for _, supportedType := range splitTypes(inst.SupportedTypes) {
switch NormalizeVisibleMethod(supportedType) {
case payment.TypeAlipay:
available[VisibleMethodSourceEasyPayAlipay] = true
case payment.TypeWxpay:
available[VisibleMethodSourceEasyPayWechat] = true
}
}
}
}
return available
}
func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
shouldExpose := map[string]bool{
payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
}
seen := make(map[string]struct{}, len(base)+2)
out := make([]string, 0, len(base)+2)
appendType := func(paymentType string) {
paymentType = NormalizeVisibleMethod(paymentType)
if paymentType == "" {
return
}
if _, ok := seen[paymentType]; ok {
return
}
seen[paymentType] = struct{}{}
out = append(out, paymentType)
}
for _, paymentType := range base {
visibleMethod := NormalizeVisibleMethod(paymentType)
switch visibleMethod {
case payment.TypeAlipay, payment.TypeWxpay:
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
default:
appendType(visibleMethod)
}
}
for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
if shouldExpose[visibleMethod] {
appendType(visibleMethod)
}
}
return out
}
func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
enabledKey := visibleMethodEnabledSettingKey(method)
sourceKey := visibleMethodSourceSettingKey(method)
if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
return false
}
source := NormalizeVisibleMethodSource(method, vals[sourceKey])
return source != "" && available[source]
}

View File

@@ -1,9 +1,17 @@
package service
import (
"context"
"database/sql"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/enttest"
"github.com/Wei-Shaw/sub2api/internal/payment"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "modernc.org/sqlite"
)
func TestPcParseFloat(t *testing.T) {
@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
}
})
t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
}
cfg := svc.parsePaymentConfig(vals)
if len(cfg.EnabledTypes) != 2 {
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
}
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
}
})
t.Run("empty enabled types string", func(t *testing.T) {
t.Parallel()
vals := map[string]string{
@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
})
}
}
func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
t.Parallel()
base := []string{"alipay", "wxpay", "stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
}
available := map[string]bool{
VisibleMethodSourceOfficialAlipay: true,
VisibleMethodSourceOfficialWechat: false,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"alipay", "stripe"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
t.Parallel()
base := []string{"stripe"}
vals := map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
}
available := map[string]bool{
VisibleMethodSourceEasyPayAlipay: true,
}
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
want := []string{"stripe", "alipay"}
if len(got) != len(want) {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
t.Parallel()
instances := []*dbent.PaymentProviderInstance{
{ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
{ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
{ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
}
got := buildVisibleMethodSourceAvailability(instances)
if !got[VisibleMethodSourceOfficialAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
}
if !got[VisibleMethodSourceEasyPayAlipay] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
}
if !got[VisibleMethodSourceOfficialWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
}
if !got[VisibleMethodSourceEasyPayWechat] {
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
}
}
func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) {
ctx := context.Background()
client := newPaymentConfigServiceTestClient(t)
_, err := client.PaymentProviderInstance.Create().
SetProviderKey(payment.TypeEasyPay).
SetName("EasyPay Alipay").
SetConfig("{}").
SetSupportedTypes("alipay").
SetEnabled(true).
Save(ctx)
if err != nil {
t.Fatalf("create easypay instance: %v", err)
}
svc := &PaymentConfigService{
entClient: client,
settingRepo: &paymentConfigSettingRepoStub{
values: map[string]string{
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: "easypay",
SettingPaymentVisibleMethodWxpayEnabled: "true",
SettingPaymentVisibleMethodWxpaySource: "wxpay",
},
},
}
cfg, err := svc.GetPaymentConfig(ctx)
if err != nil {
t.Fatalf("GetPaymentConfig returned error: %v", err)
}
want := []string{payment.TypeAlipay, payment.TypeStripe}
if len(cfg.EnabledTypes) != len(want) {
t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
}
for i := range want {
if cfg.EnabledTypes[i] != want[i] {
t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
}
}
}
func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
t.Helper()
db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared")
if err != nil {
t.Fatalf("open sqlite: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
t.Fatalf("enable foreign keys: %v", err)
}
drv := entsql.OpenDB(dialect.SQLite, db)
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
t.Cleanup(func() { _ = client.Close() })
return client
}
type paymentConfigSettingRepoStub struct {
values map[string]string
}
func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
return nil, nil
}
func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }

View File

@@ -0,0 +1,248 @@
package service
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/payment"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
const (
PaymentSourceHostedRedirect = "hosted_redirect"
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
paymentResumeFallbackSigningKey = "sub2api-payment-resume"
SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
VisibleMethodSourceOfficialAlipay = "official_alipay"
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
VisibleMethodSourceOfficialWechat = "official_wxpay"
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
)
type ResumeTokenClaims struct {
OrderID int64 `json:"oid"`
UserID int64 `json:"uid,omitempty"`
ProviderInstanceID string `json:"pi,omitempty"`
ProviderKey string `json:"pk,omitempty"`
PaymentType string `json:"pt,omitempty"`
CanonicalReturnURL string `json:"ru,omitempty"`
IssuedAt int64 `json:"iat"`
}
type PaymentResumeService struct {
signingKey []byte
}
type visibleMethodLoadBalancer struct {
inner payment.LoadBalancer
configService *PaymentConfigService
}
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
return &PaymentResumeService{signingKey: signingKey}
}
func NormalizeVisibleMethod(method string) string {
return payment.GetBasePaymentType(strings.TrimSpace(method))
}
func NormalizeVisibleMethods(methods []string) []string {
if len(methods) == 0 {
return nil
}
seen := make(map[string]struct{}, len(methods))
out := make([]string, 0, len(methods))
for _, method := range methods {
normalized := NormalizeVisibleMethod(method)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
out = append(out, normalized)
}
return out
}
func NormalizePaymentSource(source string) string {
switch strings.TrimSpace(strings.ToLower(source)) {
case "", PaymentSourceHostedRedirect:
return PaymentSourceHostedRedirect
case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
return PaymentSourceWechatInAppResume
default:
return strings.TrimSpace(strings.ToLower(source))
}
}
func NormalizeVisibleMethodSource(method, source string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
return VisibleMethodSourceOfficialAlipay
case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayAlipay
}
case payment.TypeWxpay:
switch strings.TrimSpace(strings.ToLower(source)) {
case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
return VisibleMethodSourceOfficialWechat
case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
return VisibleMethodSourceEasyPayWechat
}
}
return ""
}
func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
switch NormalizeVisibleMethodSource(method, source) {
case VisibleMethodSourceOfficialAlipay:
return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceEasyPayAlipay:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
case VisibleMethodSourceOfficialWechat:
return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
case VisibleMethodSourceEasyPayWechat:
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
default:
return "", false
}
}
func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
if inner == nil || configService == nil || configService.settingRepo == nil {
return inner
}
return &visibleMethodLoadBalancer{inner: inner, configService: configService}
}
func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
return lb.inner.GetInstanceConfig(ctx, instanceID)
}
func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
visibleMethod := NormalizeVisibleMethod(paymentType)
if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
}
enabledKey := visibleMethodEnabledSettingKey(visibleMethod)
sourceKey := visibleMethodSourceSettingKey(visibleMethod)
vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey})
if err != nil {
return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err)
}
if vals[enabledKey] != "true" {
return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod)
}
targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey])
if !ok {
return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod)
}
return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount)
}
func visibleMethodEnabledSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipayEnabled
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpayEnabled
default:
return ""
}
}
func visibleMethodSourceSettingKey(method string) string {
switch NormalizeVisibleMethod(method) {
case payment.TypeAlipay:
return SettingPaymentVisibleMethodAlipaySource
case payment.TypeWxpay:
return SettingPaymentVisibleMethodWxpaySource
default:
return ""
}
}
func CanonicalizeReturnURL(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", nil
}
parsed, err := url.Parse(raw)
if err != nil || !parsed.IsAbs() || parsed.Host == "" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
}
parsed.Fragment = ""
if parsed.Path == "" {
parsed.Path = "/"
}
return parsed.String(), nil
}
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
if claims.OrderID <= 0 {
return "", fmt.Errorf("resume token requires order id")
}
if claims.IssuedAt == 0 {
claims.IssuedAt = time.Now().Unix()
}
payload, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("marshal resume claims: %w", err)
}
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
return encodedPayload + "." + s.sign(encodedPayload), nil
}
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
parts := strings.Split(token, ".")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
}
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
}
var claims ResumeTokenClaims
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
}
if claims.OrderID <= 0 {
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
}
return &claims, nil
}
func (s *PaymentResumeService) sign(payload string) string {
key := s.signingKey
if len(key) == 0 {
key = []byte(paymentResumeFallbackSigningKey)
}
mac := hmac.New(sha256.New, key)
_, _ = mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}

View File

@@ -0,0 +1,240 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/payment"
)
func TestNormalizeVisibleMethods(t *testing.T) {
t.Parallel()
got := NormalizeVisibleMethods([]string{
"alipay_direct",
"alipay",
" wxpay_direct ",
"wxpay",
"stripe",
})
want := []string{"alipay", "wxpay", "stripe"}
if len(got) != len(want) {
t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
}
}
}
func TestNormalizePaymentSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
expect string
}{
{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizePaymentSource(tt.input); got != tt.expect {
t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
}
})
}
}
func TestCanonicalizeReturnURL(t *testing.T) {
t.Parallel()
got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
if err != nil {
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
}
if got != "https://example.com/pay/result?b=2" {
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
}
}
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
t.Parallel()
if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
}
}
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
t.Parallel()
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
token, err := svc.CreateToken(ResumeTokenClaims{
OrderID: 42,
UserID: 7,
ProviderInstanceID: "19",
ProviderKey: "easypay",
PaymentType: "wxpay",
CanonicalReturnURL: "https://example.com/payment/result",
IssuedAt: 1234567890,
})
if err != nil {
t.Fatalf("CreateToken returned error: %v", err)
}
claims, err := svc.ParseToken(token)
if err != nil {
t.Fatalf("ParseToken returned error: %v", err)
}
if claims.OrderID != 42 || claims.UserID != 7 {
t.Fatalf("claims mismatch: %+v", claims)
}
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
}
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
}
}
func TestNormalizeVisibleMethodSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
input string
want string
}{
{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
}
})
}
}
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
t.Parallel()
tests := []struct {
name string
method string
source string
want string
ok bool
}{
{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
if got != tt.want || ok != tt.ok {
t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
}
})
}
}
func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodAlipayEnabled: "true",
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
_, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
if err != nil {
t.Fatalf("SelectInstance returned error: %v", err)
}
if inner.lastProviderKey != payment.TypeAlipay {
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
}
}
func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
t.Parallel()
inner := &captureLoadBalancer{}
configService := &PaymentConfigService{
settingRepo: &paymentSettingRepoStub{
values: map[string]string{
SettingPaymentVisibleMethodWxpayEnabled: "false",
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
},
},
}
lb := newVisibleMethodLoadBalancer(inner, configService)
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
t.Fatal("SelectInstance should reject disabled visible method")
}
}
type paymentSettingRepoStub struct {
values map[string]string
}
func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil }
func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
return s.values[key], nil
}
func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil }
func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
out[key] = s.values[key]
}
return out, nil
}
func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil }
func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return s.values, nil
}
func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil }
type captureLoadBalancer struct {
lastProviderKey string
lastPaymentType string
}
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
return map[string]string{}, nil
}
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
c.lastProviderKey = providerKey
c.lastPaymentType = paymentType
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
}

View File

@@ -65,15 +65,17 @@ func generateRandomString(n int) string {
}
type CreateOrderRequest struct {
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
OrderType string
PlanID int64
UserID int64
Amount float64
PaymentType string
ClientIP string
IsMobile bool
SrcHost string
SrcURL string
ReturnURL string
PaymentSource string
OrderType string
PlanID int64
}
type CreateOrderResponse struct {
@@ -88,6 +90,7 @@ type CreateOrderResponse struct {
ClientSecret string `json:"client_secret,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
PaymentMode string `json:"payment_mode,omitempty"`
ResumeToken string `json:"resume_token,omitempty"`
}
type OrderListParams struct {
@@ -165,10 +168,13 @@ type PaymentService struct {
configService *PaymentConfigService
userRepo UserRepository
groupRepo GroupRepository
resumeService *PaymentResumeService
}
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
return svc
}
// --- Provider Registry ---
@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
return &s
}
func (s *PaymentService) paymentResume() *PaymentResumeService {
if s.resumeService != nil {
return s.resumeService
}
return NewPaymentResumeService(psResumeSigningKey(s.configService))
}
func psResumeSigningKey(configService *PaymentConfigService) []byte {
if configService == nil {
return nil
}
return configService.encryptionKey
}
func psSliceContains(sl []string, s string) bool {
for _, v := range sl {
if v == s {

View File

@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net/url"
"os"
"sort"
"strconv"
"strings"
@@ -114,6 +115,66 @@ type SettingService struct {
webSearchManagerBuilder WebSearchManagerBuilder
}
type ProviderDefaultGrantSettings struct {
Balance float64
Concurrency int
Subscriptions []DefaultSubscriptionSetting
GrantOnSignup bool
GrantOnFirstBind bool
}
type AuthSourceDefaultSettings struct {
Email ProviderDefaultGrantSettings
LinuxDo ProviderDefaultGrantSettings
OIDC ProviderDefaultGrantSettings
WeChat ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup bool
}
type authSourceDefaultKeySet struct {
balance string
concurrency string
subscriptions string
grantOnSignup string
grantOnFirstBind string
}
var (
emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultEmailBalance,
concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
}
linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
}
oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultOIDCBalance,
concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
}
weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
balance: SettingKeyAuthSourceDefaultWeChatBalance,
concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
}
)
const (
defaultAuthSourceBalance = 0
defaultAuthSourceConcurrency = 5
)
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
if oidcProviderName == "" {
oidcProviderName = "OIDC"
}
weChatEnabled := isWeChatOAuthConfigured()
// Password reset requires email verification to be enabled
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
WeChatOAuthEnabled: weChatEnabled,
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
OIDCOAuthEnabled: oidcEnabled,
@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"`
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled,
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
return result
}
func isWeChatOAuthConfigured() bool {
openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
return openConfigured || mpConfigured
}
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
func safeRawJSONArray(raw string) json.RawMessage {
raw = strings.TrimSpace(raw)
@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
return parseDefaultSubscriptions(value)
}
func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
keys := []string{
SettingKeyAuthSourceDefaultEmailBalance,
SettingKeyAuthSourceDefaultEmailConcurrency,
SettingKeyAuthSourceDefaultEmailSubscriptions,
SettingKeyAuthSourceDefaultEmailGrantOnSignup,
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
SettingKeyAuthSourceDefaultLinuxDoBalance,
SettingKeyAuthSourceDefaultLinuxDoConcurrency,
SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
SettingKeyAuthSourceDefaultOIDCBalance,
SettingKeyAuthSourceDefaultOIDCConcurrency,
SettingKeyAuthSourceDefaultOIDCSubscriptions,
SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
SettingKeyAuthSourceDefaultWeChatBalance,
SettingKeyAuthSourceDefaultWeChatConcurrency,
SettingKeyAuthSourceDefaultWeChatSubscriptions,
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
SettingKeyForceEmailOnThirdPartySignup,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return nil, fmt.Errorf("get auth source default settings: %w", err)
}
return &AuthSourceDefaultSettings{
Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
}, nil
}
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
if settings == nil {
return nil
}
for _, subscriptions := range [][]DefaultSubscriptionSetting{
settings.Email.Subscriptions,
settings.LinuxDo.Subscriptions,
settings.OIDC.Subscriptions,
settings.WeChat.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return err
}
}
updates := make(map[string]string, 21)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
return fmt.Errorf("update auth source default settings: %w", err)
}
return nil
}
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "false",
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeyTableDefaultPageSize: "20",
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
SettingKeyCustomMenuItems: "[]",
SettingKeyCustomEndpoints: "[]",
SettingKeyOIDCConnectEnabled: "false",
SettingKeyOIDCConnectProviderName: "OIDC",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeyDefaultSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailBalance: "0",
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultOIDCBalance: "0",
SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true",
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
SettingKeyAuthSourceDefaultWeChatBalance: "0",
SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true",
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
SettingKeyForceEmailOnThirdPartySignup: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
}
result.OIDCConnectUsePKCE = true
result.OIDCConnectValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
} else {
@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
return normalized
}
func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
result := ProviderDefaultGrantSettings{
Balance: defaultAuthSourceBalance,
Concurrency: defaultAuthSourceConcurrency,
Subscriptions: []DefaultSubscriptionSetting{},
GrantOnSignup: true,
GrantOnFirstBind: false,
}
if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
result.Balance = v
}
if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
result.Concurrency = v
}
if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
result.Subscriptions = items
}
if raw, ok := settings[keys.grantOnSignup]; ok {
result.GrantOnSignup = raw == "true"
}
if raw, ok := settings[keys.grantOnFirstBind]; ok {
result.GrantOnFirstBind = raw == "true"
}
return result
}
func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
subscriptions := settings.Subscriptions
if subscriptions == nil {
subscriptions = []DefaultSubscriptionSetting{}
}
raw, err := json.Marshal(subscriptions)
if err != nil {
raw = []byte("[]")
}
updates[keys.subscriptions] = string(raw)
updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
}
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
defaultPageSize := 20
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
effective.RedirectURL = strings.TrimSpace(v)
}
effective.UsePKCE = true
if !effective.Enabled {
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}
@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
effective.ValidateIDToken = raw == "true"
}
effective.UsePKCE = true
effective.ValidateIDToken = true
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
effective.AllowedSigningAlgs = strings.TrimSpace(v)
}
@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
}
case "none":
if !effective.UsePKCE {
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
}
default:
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
}

View File

@@ -0,0 +1,136 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type authSourceDefaultsRepoStub struct {
values map[string]string
updates map[string]string
}
func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
panic("unexpected GetValue call")
}
func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, key := range keys {
if value, ok := s.values[key]; ok {
out[key] = value
}
}
return out, nil
}
func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
s.updates = make(map[string]string, len(settings))
for key, value := range settings {
s.updates[key] = value
if s.values == nil {
s.values = map[string]string{}
}
s.values[key] = value
}
return nil
}
func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
repo := &authSourceDefaultsRepoStub{
values: map[string]string{
SettingKeyAuthSourceDefaultEmailBalance: "12.5",
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
SettingKeyForceEmailOnThirdPartySignup: "true",
},
}
svc := NewSettingService(repo, &config.Config{})
got, err := svc.GetAuthSourceDefaultSettings(context.Background())
require.NoError(t, err)
require.Equal(t, 12.5, got.Email.Balance)
require.Equal(t, 7, got.Email.Concurrency)
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
require.False(t, got.Email.GrantOnSignup)
require.False(t, got.Email.GrantOnFirstBind)
require.Equal(t, 0.0, got.LinuxDo.Balance)
require.Equal(t, 5, got.LinuxDo.Concurrency)
require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
require.True(t, got.LinuxDo.GrantOnSignup)
require.True(t, got.LinuxDo.GrantOnFirstBind)
require.Equal(t, 5, got.OIDC.Concurrency)
require.Equal(t, 5, got.WeChat.Concurrency)
require.True(t, got.ForceEmailOnThirdPartySignup)
}
func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
repo := &authSourceDefaultsRepoStub{}
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
Email: ProviderDefaultGrantSettings{
Balance: 1.25,
Concurrency: 3,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
GrantOnSignup: false,
GrantOnFirstBind: true,
},
LinuxDo: ProviderDefaultGrantSettings{
Balance: 2,
Concurrency: 4,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
GrantOnSignup: true,
GrantOnFirstBind: false,
},
OIDC: ProviderDefaultGrantSettings{
Balance: 3,
Concurrency: 5,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
GrantOnSignup: true,
GrantOnFirstBind: true,
},
WeChat: ProviderDefaultGrantSettings{
Balance: 4,
Concurrency: 6,
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
GrantOnSignup: false,
GrantOnFirstBind: false,
},
ForceEmailOnThirdPartySignup: true,
})
require.NoError(t, err)
require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
var got []DefaultSubscriptionSetting
require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
}

View File

@@ -152,6 +152,7 @@ type PublicSettings struct {
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
WeChatOAuthEnabled bool
BackendModeEnabled bool
PaymentEnabled bool
OIDCOAuthEnabled bool

View File

@@ -7,19 +7,27 @@ import (
)
type User struct {
ID int64
Email string
Username string
Notes string
PasswordHash string
Role string
Balance float64
Concurrency int
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
CreatedAt time.Time
UpdatedAt time.Time
ID int64
Email string
Username string
Notes string
AvatarURL string
AvatarSource string
AvatarMIME string
AvatarByteSize int
AvatarSHA256 string
PasswordHash string
Role string
Balance float64
Concurrency int
Status string
AllowedGroups []int64
TokenVersion int64 // Incremented on password change to invalidate existing tokens
SignupSource string
LastLoginAt *time.Time
LastActiveAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier

View File

@@ -2,9 +2,13 @@ package service
import (
"context"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"strings"
"time"
@@ -17,10 +21,14 @@ var (
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
)
const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
maxNotifyEmails = 3 // Maximum number of notification emails per user
maxInlineAvatarBytes = 100 * 1024
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
@@ -47,6 +55,9 @@ type UserRepository interface {
GetFirstAdmin(ctx context.Context) (*User, error)
Update(ctx context.Context, user *User) error
Delete(ctx context.Context, id int64) error
GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
DeleteUserAvatar(ctx context.Context, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
@@ -71,11 +82,30 @@ type UserRepository interface {
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
AvatarURL *string `json:"avatar_url"`
Concurrency *int `json:"concurrency"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
}
type UserAvatar struct {
StorageProvider string
StorageKey string
URL string
ContentType string
ByteSize int
SHA256 string
}
type UpsertUserAvatarInput struct {
StorageProvider string
StorageKey string
URL string
ContentType string
ByteSize int
SHA256 string
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
CurrentPassword string `json:"current_password"`
@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil
}
@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
if req.AvatarURL != nil {
avatarValue := strings.TrimSpace(*req.AvatarURL)
switch {
case avatarValue == "":
if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
return nil, fmt.Errorf("delete avatar: %w", err)
}
applyUserAvatar(user, nil)
default:
avatarInput, err := normalizeUserAvatarInput(avatarValue)
if err != nil {
return nil, err
}
avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
if err != nil {
return nil, fmt.Errorf("upsert avatar: %w", err)
}
applyUserAvatar(user, avatar)
}
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}
@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
return user, nil
}
func applyUserAvatar(user *User, avatar *UserAvatar) {
if user == nil {
return
}
if avatar == nil {
user.AvatarURL = ""
user.AvatarSource = ""
user.AvatarMIME = ""
user.AvatarByteSize = 0
user.AvatarSHA256 = ""
return
}
user.AvatarURL = avatar.URL
user.AvatarSource = avatar.StorageProvider
user.AvatarMIME = avatar.ContentType
user.AvatarByteSize = avatar.ByteSize
user.AvatarSHA256 = avatar.SHA256
}
func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if strings.HasPrefix(raw, "data:") {
return normalizeInlineUserAvatarInput(raw)
}
parsed, err := url.Parse(raw)
if err != nil || parsed == nil {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if strings.TrimSpace(parsed.Host) == "" {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
return UpsertUserAvatarInput{
StorageProvider: "remote_url",
URL: raw,
}, nil
}
func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
body := strings.TrimPrefix(raw, "data:")
meta, encoded, ok := strings.Cut(body, ",")
if !ok {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
meta = strings.TrimSpace(meta)
encoded = strings.TrimSpace(encoded)
if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
return UpsertUserAvatarInput{}, ErrAvatarNotImage
}
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return UpsertUserAvatarInput{}, ErrAvatarInvalid
}
if len(decoded) > maxInlineAvatarBytes {
return UpsertUserAvatarInput{}, ErrAvatarTooLarge
}
sum := sha256.Sum256(decoded)
return UpsertUserAvatarInput{
StorageProvider: "inline",
URL: raw,
ContentType: contentType,
ByteSize: len(decoded),
SHA256: hex.EncodeToString(sum[:]),
}, nil
}
// ChangePassword 修改密码
// Security: Increments TokenVersion to invalidate all existing JWT tokens
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if err := s.hydrateUserAvatar(ctx, user); err != nil {
return nil, fmt.Errorf("get user avatar: %w", err)
}
return user, nil
}
func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
return nil
}
avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
if err != nil {
return err
}
applyUserAvatar(user, avatar)
return nil
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)

View File

@@ -4,6 +4,9 @@ package service
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"sync"
"sync/atomic"
@@ -19,14 +22,65 @@ import (
type mockUserRepo struct {
updateBalanceErr error
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
getByIDUser *User
getByIDErr error
updateFn func(ctx context.Context, user *User) error
updateCalls int
upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
upsertAvatarArgs []UpsertUserAvatarInput
deleteAvatarFn func(ctx context.Context, userID int64) error
deleteAvatarIDs []int64
getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
}
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
if m.getByIDErr != nil {
return nil, m.getByIDErr
}
if m.getByIDUser != nil {
cloned := *m.getByIDUser
return &cloned, nil
}
return &User{}, nil
}
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
m.updateCalls++
if m.updateFn != nil {
return m.updateFn(ctx, user)
}
return nil
}
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
if m.getAvatarFn != nil {
return m.getAvatarFn(ctx, userID)
}
return nil, nil
}
func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
if m.upsertAvatarFn != nil {
return m.upsertAvatarFn(ctx, userID, input)
}
return &UserAvatar{
StorageProvider: input.StorageProvider,
StorageKey: input.StorageKey,
URL: input.URL,
ContentType: input.ContentType,
ByteSize: input.ByteSize,
SHA256: input.SHA256,
}, nil
}
func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
if m.deleteAvatarFn != nil {
return m.deleteAvatarFn(ctx, userID)
}
return nil
}
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
return nil, nil, nil
}
@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
require.Equal(t, auth, svc.authCacheInvalidator)
require.Equal(t, cache, svc.billingCache)
}
func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
raw := []byte("small-avatar")
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
expectedSum := sha256.Sum256(raw)
repo := &mockUserRepo{
getByIDUser: &User{
ID: 7,
Email: "avatar@example.com",
Username: "avatar-user",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
AvatarURL: &dataURL,
})
require.NoError(t, err)
require.Len(t, repo.upsertAvatarArgs, 1)
require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
require.Equal(t, dataURL, updated.AvatarURL)
require.Equal(t, "inline", updated.AvatarSource)
require.Equal(t, "image/png", updated.AvatarMIME)
require.Equal(t, len(raw), updated.AvatarByteSize)
require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
}
func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
raw := make([]byte, maxInlineAvatarBytes+1)
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
repo := &mockUserRepo{
getByIDUser: &User{
ID: 8,
Email: "large-avatar@example.com",
Username: "too-large",
},
}
svc := NewUserService(repo, nil, nil, nil)
_, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
AvatarURL: &dataURL,
})
require.ErrorIs(t, err, ErrAvatarTooLarge)
require.Empty(t, repo.upsertAvatarArgs)
require.Empty(t, repo.deleteAvatarIDs)
require.Zero(t, repo.updateCalls)
}
func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
remoteURL := "https://cdn.example.com/avatar.png"
repo := &mockUserRepo{
getByIDUser: &User{
ID: 9,
Email: "remote-avatar@example.com",
Username: "remote-avatar",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
AvatarURL: &remoteURL,
})
require.NoError(t, err)
require.Len(t, repo.upsertAvatarArgs, 1)
require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
require.Equal(t, remoteURL, updated.AvatarURL)
require.Equal(t, "remote_url", updated.AvatarSource)
require.Zero(t, updated.AvatarByteSize)
}
func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
empty := ""
repo := &mockUserRepo{
getByIDUser: &User{
ID: 10,
Email: "delete-avatar@example.com",
Username: "delete-avatar",
AvatarURL: "https://cdn.example.com/old.png",
AvatarSource: "remote_url",
},
}
svc := NewUserService(repo, nil, nil, nil)
updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
AvatarURL: &empty,
})
require.NoError(t, err)
require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
require.Empty(t, repo.upsertAvatarArgs)
require.Empty(t, updated.AvatarURL)
require.Empty(t, updated.AvatarSource)
}
func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
repo := &mockUserRepo{
getByIDUser: &User{
ID: 12,
Email: "profile-avatar@example.com",
Username: "profile-avatar",
},
getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
return &UserAvatar{
StorageProvider: "remote_url",
URL: "https://cdn.example.com/profile.png",
}, nil
},
}
svc := NewUserService(repo, nil, nil, nil)
user, err := svc.GetProfile(context.Background(), 12)
require.NoError(t, err)
require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
require.Equal(t, "remote_url", user.AvatarSource)
}