feat: rebuild auth identity foundation flow
This commit is contained in:
@@ -44,6 +44,15 @@ func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, erro
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetUserAvatar(context.Context, int64) (*UserAvatar, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpsertUserAvatar(context.Context, int64, UpsertUserAvatarInput) (*UserAvatar, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) DeleteUserAvatar(context.Context, int64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
@@ -62,6 +62,18 @@ func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
|
||||
panic("unexpected GetUserAvatar call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
|
||||
panic("unexpected UpsertUserAvatar call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DeleteUserAvatar(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DeleteUserAvatar call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
326
backend/internal/service/auth_pending_identity_service.go
Normal file
326
backend/internal/service/auth_pending_identity_service.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/identityadoptiondecision"
|
||||
"github.com/Wei-Shaw/sub2api/ent/pendingauthsession"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPendingAuthSessionNotFound = infraerrors.NotFound("PENDING_AUTH_SESSION_NOT_FOUND", "pending auth session not found")
|
||||
ErrPendingAuthSessionExpired = infraerrors.Unauthorized("PENDING_AUTH_SESSION_EXPIRED", "pending auth session has expired")
|
||||
ErrPendingAuthSessionConsumed = infraerrors.Unauthorized("PENDING_AUTH_SESSION_CONSUMED", "pending auth session has already been used")
|
||||
ErrPendingAuthCodeInvalid = infraerrors.Unauthorized("PENDING_AUTH_CODE_INVALID", "pending auth completion code is invalid")
|
||||
ErrPendingAuthCodeExpired = infraerrors.Unauthorized("PENDING_AUTH_CODE_EXPIRED", "pending auth completion code has expired")
|
||||
ErrPendingAuthCodeConsumed = infraerrors.Unauthorized("PENDING_AUTH_CODE_CONSUMED", "pending auth completion code has already been used")
|
||||
ErrPendingAuthBrowserMismatch = infraerrors.Unauthorized("PENDING_AUTH_BROWSER_MISMATCH", "pending auth completion code does not match this browser session")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPendingAuthTTL = 15 * time.Minute
|
||||
defaultPendingAuthCompletionTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
type PendingAuthIdentityKey struct {
|
||||
ProviderType string
|
||||
ProviderKey string
|
||||
ProviderSubject string
|
||||
}
|
||||
|
||||
type CreatePendingAuthSessionInput struct {
|
||||
SessionToken string
|
||||
Intent string
|
||||
Identity PendingAuthIdentityKey
|
||||
TargetUserID *int64
|
||||
RedirectTo string
|
||||
ResolvedEmail string
|
||||
RegistrationPasswordHash string
|
||||
BrowserSessionKey string
|
||||
UpstreamIdentityClaims map[string]any
|
||||
LocalFlowState map[string]any
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type IssuePendingAuthCompletionCodeInput struct {
|
||||
PendingAuthSessionID int64
|
||||
BrowserSessionKey string
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
type IssuePendingAuthCompletionCodeResult struct {
|
||||
Code string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type PendingIdentityAdoptionDecisionInput struct {
|
||||
PendingAuthSessionID int64
|
||||
IdentityID *int64
|
||||
AdoptDisplayName bool
|
||||
AdoptAvatar bool
|
||||
}
|
||||
|
||||
type AuthPendingIdentityService struct {
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
func NewAuthPendingIdentityService(entClient *dbent.Client) *AuthPendingIdentityService {
|
||||
return &AuthPendingIdentityService{entClient: entClient}
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) CreatePendingSession(ctx context.Context, input CreatePendingAuthSessionInput) (*dbent.PendingAuthSession, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
sessionToken := strings.TrimSpace(input.SessionToken)
|
||||
if sessionToken == "" {
|
||||
var err error
|
||||
sessionToken, err = randomOpaqueToken(24)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt := input.ExpiresAt.UTC()
|
||||
if expiresAt.IsZero() {
|
||||
expiresAt = time.Now().UTC().Add(defaultPendingAuthTTL)
|
||||
}
|
||||
|
||||
create := s.entClient.PendingAuthSession.Create().
|
||||
SetSessionToken(sessionToken).
|
||||
SetIntent(strings.TrimSpace(input.Intent)).
|
||||
SetProviderType(strings.TrimSpace(input.Identity.ProviderType)).
|
||||
SetProviderKey(strings.TrimSpace(input.Identity.ProviderKey)).
|
||||
SetProviderSubject(strings.TrimSpace(input.Identity.ProviderSubject)).
|
||||
SetRedirectTo(strings.TrimSpace(input.RedirectTo)).
|
||||
SetResolvedEmail(strings.TrimSpace(input.ResolvedEmail)).
|
||||
SetRegistrationPasswordHash(strings.TrimSpace(input.RegistrationPasswordHash)).
|
||||
SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey)).
|
||||
SetUpstreamIdentityClaims(copyPendingMap(input.UpstreamIdentityClaims)).
|
||||
SetLocalFlowState(copyPendingMap(input.LocalFlowState)).
|
||||
SetExpiresAt(expiresAt)
|
||||
if input.TargetUserID != nil {
|
||||
create = create.SetTargetUserID(*input.TargetUserID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) IssueCompletionCode(ctx context.Context, input IssuePendingAuthCompletionCodeInput) (*IssuePendingAuthCompletionCodeResult, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
session, err := s.entClient.PendingAuthSession.Get(ctx, input.PendingAuthSessionID)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, ErrPendingAuthSessionNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
code, err := randomOpaqueToken(24)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ttl := input.TTL
|
||||
if ttl <= 0 {
|
||||
ttl = defaultPendingAuthCompletionTTL
|
||||
}
|
||||
expiresAt := time.Now().UTC().Add(ttl)
|
||||
|
||||
update := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
|
||||
SetCompletionCodeHash(hashPendingAuthCode(code)).
|
||||
SetCompletionCodeExpiresAt(expiresAt)
|
||||
if strings.TrimSpace(input.BrowserSessionKey) != "" {
|
||||
update = update.SetBrowserSessionKey(strings.TrimSpace(input.BrowserSessionKey))
|
||||
}
|
||||
if _, err := update.Save(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &IssuePendingAuthCompletionCodeResult{
|
||||
Code: code,
|
||||
ExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) ConsumeCompletionCode(ctx context.Context, rawCode, browserSessionKey string) (*dbent.PendingAuthSession, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
codeHash := hashPendingAuthCode(strings.TrimSpace(rawCode))
|
||||
session, err := s.entClient.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.CompletionCodeHashEQ(codeHash)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, ErrPendingAuthCodeInvalid
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthCodeExpired, ErrPendingAuthCodeConsumed)
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) ConsumeBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
session, err := s.getBrowserSession(ctx, sessionToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.consumeSession(ctx, session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed)
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) GetBrowserSession(ctx context.Context, sessionToken, browserSessionKey string) (*dbent.PendingAuthSession, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
session, err := s.getBrowserSession(ctx, sessionToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validatePendingSessionState(session, browserSessionKey, ErrPendingAuthSessionExpired, ErrPendingAuthSessionConsumed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) getBrowserSession(ctx context.Context, sessionToken string) (*dbent.PendingAuthSession, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
sessionToken = strings.TrimSpace(sessionToken)
|
||||
if sessionToken == "" {
|
||||
return nil, ErrPendingAuthSessionNotFound
|
||||
}
|
||||
|
||||
session, err := s.entClient.PendingAuthSession.Query().
|
||||
Where(pendingauthsession.SessionTokenEQ(sessionToken)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, ErrPendingAuthSessionNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) consumeSession(
|
||||
ctx context.Context,
|
||||
session *dbent.PendingAuthSession,
|
||||
browserSessionKey string,
|
||||
expiredErr error,
|
||||
consumedErr error,
|
||||
) (*dbent.PendingAuthSession, error) {
|
||||
if err := validatePendingSessionState(session, browserSessionKey, expiredErr, consumedErr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
updated, err := s.entClient.PendingAuthSession.UpdateOneID(session.ID).
|
||||
SetConsumedAt(now).
|
||||
SetCompletionCodeHash("").
|
||||
ClearCompletionCodeExpiresAt().
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func validatePendingSessionState(session *dbent.PendingAuthSession, browserSessionKey string, expiredErr error, consumedErr error) error {
|
||||
if session == nil {
|
||||
return ErrPendingAuthSessionNotFound
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if session.ConsumedAt != nil {
|
||||
return consumedErr
|
||||
}
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
return expiredErr
|
||||
}
|
||||
if session.CompletionCodeExpiresAt != nil && now.After(*session.CompletionCodeExpiresAt) {
|
||||
return expiredErr
|
||||
}
|
||||
if strings.TrimSpace(session.BrowserSessionKey) != "" && strings.TrimSpace(browserSessionKey) != strings.TrimSpace(session.BrowserSessionKey) {
|
||||
return ErrPendingAuthBrowserMismatch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthPendingIdentityService) UpsertAdoptionDecision(ctx context.Context, input PendingIdentityAdoptionDecisionInput) (*dbent.IdentityAdoptionDecision, error) {
|
||||
if s == nil || s.entClient == nil {
|
||||
return nil, fmt.Errorf("pending auth ent client is not configured")
|
||||
}
|
||||
|
||||
existing, err := s.entClient.IdentityAdoptionDecision.Query().
|
||||
Where(identityadoptiondecision.PendingAuthSessionIDEQ(input.PendingAuthSessionID)).
|
||||
Only(ctx)
|
||||
if err != nil && !dbent.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if existing == nil {
|
||||
create := s.entClient.IdentityAdoptionDecision.Create().
|
||||
SetPendingAuthSessionID(input.PendingAuthSessionID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar).
|
||||
SetDecidedAt(time.Now().UTC())
|
||||
if input.IdentityID != nil {
|
||||
create = create.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return create.Save(ctx)
|
||||
}
|
||||
|
||||
update := s.entClient.IdentityAdoptionDecision.UpdateOneID(existing.ID).
|
||||
SetAdoptDisplayName(input.AdoptDisplayName).
|
||||
SetAdoptAvatar(input.AdoptAvatar)
|
||||
if input.IdentityID != nil {
|
||||
update = update.SetIdentityID(*input.IdentityID)
|
||||
}
|
||||
return update.Save(ctx)
|
||||
}
|
||||
|
||||
func copyPendingMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func randomOpaqueToken(byteLen int) (string, error) {
|
||||
if byteLen <= 0 {
|
||||
byteLen = 16
|
||||
}
|
||||
buf := make([]byte, byteLen)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
func hashPendingAuthCode(code string) string {
|
||||
sum := sha256.Sum256([]byte(code))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
224
backend/internal/service/auth_pending_identity_service_test.go
Normal file
224
backend/internal/service/auth_pending_identity_service_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newAuthPendingIdentityServiceTestClient(t *testing.T) (*AuthPendingIdentityService, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_pending_identity_service?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
return NewAuthPendingIdentityService(client), client
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_CreatePendingSessionStoresSeparatedState(t *testing.T) {
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
targetUser, err := client.User.Create().
|
||||
SetEmail("pending-target@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(RoleUser).
|
||||
SetStatus(StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "bind_current_user",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
ProviderSubject: "union-123",
|
||||
},
|
||||
TargetUserID: &targetUser.ID,
|
||||
RedirectTo: "/profile",
|
||||
ResolvedEmail: "user@example.com",
|
||||
BrowserSessionKey: "browser-1",
|
||||
UpstreamIdentityClaims: map[string]any{"nickname": "wx-user", "avatar_url": "https://cdn.example/avatar.png"},
|
||||
LocalFlowState: map[string]any{"step": "email_required"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, session.SessionToken)
|
||||
require.Equal(t, "bind_current_user", session.Intent)
|
||||
require.Equal(t, "wechat", session.ProviderType)
|
||||
require.NotNil(t, session.TargetUserID)
|
||||
require.Equal(t, targetUser.ID, *session.TargetUserID)
|
||||
require.Equal(t, "wx-user", session.UpstreamIdentityClaims["nickname"])
|
||||
require.Equal(t, "email_required", session.LocalFlowState["step"])
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_CompletionCodeIsBrowserBoundAndOneTime(t *testing.T) {
|
||||
svc, _ := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "login",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo-main",
|
||||
ProviderSubject: "subject-1",
|
||||
},
|
||||
BrowserSessionKey: "browser-expected",
|
||||
UpstreamIdentityClaims: map[string]any{"nickname": "linux-user"},
|
||||
LocalFlowState: map[string]any{"step": "pending"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
BrowserSessionKey: "browser-expected",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, issued.Code)
|
||||
|
||||
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-other")
|
||||
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
|
||||
|
||||
consumed, err := svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
require.Empty(t, consumed.CompletionCodeHash)
|
||||
require.Nil(t, consumed.CompletionCodeExpiresAt)
|
||||
|
||||
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expected")
|
||||
require.ErrorIs(t, err, ErrPendingAuthCodeInvalid)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_CompletionCodeExpires(t *testing.T) {
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "login",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: "https://issuer.example",
|
||||
ProviderSubject: "subject-1",
|
||||
},
|
||||
BrowserSessionKey: "browser-expired",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
issued, err := svc.IssueCompletionCode(ctx, IssuePendingAuthCompletionCodeInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
BrowserSessionKey: "browser-expired",
|
||||
TTL: time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.PendingAuthSession.UpdateOneID(session.ID).
|
||||
SetCompletionCodeExpiresAt(time.Now().UTC().Add(-time.Minute)).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.ConsumeCompletionCode(ctx, issued.Code, "browser-expired")
|
||||
require.ErrorIs(t, err, ErrPendingAuthCodeExpired)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_UpsertAdoptionDecision(t *testing.T) {
|
||||
svc, client := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := client.User.Create().
|
||||
SetEmail("adoption@example.com").
|
||||
SetPasswordHash("hash").
|
||||
SetRole(RoleUser).
|
||||
SetStatus(StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
identity, err := client.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("wechat").
|
||||
SetProviderKey("wechat-open").
|
||||
SetProviderSubject("union-adoption").
|
||||
SetMetadata(map[string]any{}).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "bind_current_user",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "wechat",
|
||||
ProviderKey: "wechat-open",
|
||||
ProviderSubject: "union-adoption",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
first, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, first.AdoptDisplayName)
|
||||
require.False(t, first.AdoptAvatar)
|
||||
require.Nil(t, first.IdentityID)
|
||||
|
||||
second, err := svc.UpsertAdoptionDecision(ctx, PendingIdentityAdoptionDecisionInput{
|
||||
PendingAuthSessionID: session.ID,
|
||||
IdentityID: &identity.ID,
|
||||
AdoptDisplayName: true,
|
||||
AdoptAvatar: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, first.ID, second.ID)
|
||||
require.NotNil(t, second.IdentityID)
|
||||
require.Equal(t, identity.ID, *second.IdentityID)
|
||||
require.True(t, second.AdoptAvatar)
|
||||
}
|
||||
|
||||
func TestAuthPendingIdentityService_ConsumeBrowserSession(t *testing.T) {
|
||||
svc, _ := newAuthPendingIdentityServiceTestClient(t)
|
||||
ctx := context.Background()
|
||||
|
||||
session, err := svc.CreatePendingSession(ctx, CreatePendingAuthSessionInput{
|
||||
Intent: "login",
|
||||
Identity: PendingAuthIdentityKey{
|
||||
ProviderType: "linuxdo",
|
||||
ProviderKey: "linuxdo",
|
||||
ProviderSubject: "subject-session-token",
|
||||
},
|
||||
BrowserSessionKey: "browser-session",
|
||||
LocalFlowState: map[string]any{
|
||||
"completion_response": map[string]any{
|
||||
"access_token": "token",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-other")
|
||||
require.ErrorIs(t, err, ErrPendingAuthBrowserMismatch)
|
||||
|
||||
consumed, err := svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, consumed.ConsumedAt)
|
||||
|
||||
_, err = svc.ConsumeBrowserSession(ctx, session.SessionToken, "browser-session")
|
||||
require.ErrorIs(t, err, ErrPendingAuthSessionConsumed)
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@@ -106,6 +107,13 @@ func NewAuthService(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) EntClient() *dbent.Client {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.entClient
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "", "")
|
||||
@@ -205,6 +213,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
s.postAuthUserBootstrap(ctx, user, "email", true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
|
||||
// 标记邀请码为已使用(如果使用了邀请码)
|
||||
@@ -421,6 +430,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string
|
||||
if !user.IsActive() {
|
||||
return "", nil, ErrUserNotActive
|
||||
}
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
|
||||
// 生成JWT token
|
||||
token, err := s.GenerateToken(user)
|
||||
@@ -501,6 +511,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
@@ -520,6 +531,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
@@ -630,6 +642,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return nil, nil, ErrServiceUnavailable
|
||||
}
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
}
|
||||
} else {
|
||||
@@ -646,6 +659,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
}
|
||||
} else {
|
||||
user = newUser
|
||||
s.postAuthUserBootstrap(ctx, user, inferLegacySignupSource(email), true)
|
||||
s.assignDefaultSubscriptions(ctx, user.ID)
|
||||
if invitationRedeemCode != nil {
|
||||
if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil {
|
||||
@@ -670,6 +684,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err)
|
||||
}
|
||||
}
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
|
||||
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
|
||||
if err != nil {
|
||||
@@ -678,63 +693,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
|
||||
return tokenPair, user, nil
|
||||
}
|
||||
|
||||
// pendingOAuthTokenTTL is the validity period for pending OAuth tokens.
|
||||
const pendingOAuthTokenTTL = 10 * time.Minute
|
||||
|
||||
// pendingOAuthPurpose is the purpose claim value for pending OAuth registration tokens.
|
||||
const pendingOAuthPurpose = "pending_oauth_registration"
|
||||
|
||||
type pendingOAuthClaims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Purpose string `json:"purpose"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// CreatePendingOAuthToken generates a short-lived JWT that carries the OAuth identity
|
||||
// while waiting for the user to supply an invitation code.
|
||||
func (s *AuthService) CreatePendingOAuthToken(email, username string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: email,
|
||||
Username: username,
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(pendingOAuthTokenTTL)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(s.cfg.JWT.Secret))
|
||||
}
|
||||
|
||||
// VerifyPendingOAuthToken validates a pending OAuth token and returns the embedded identity.
|
||||
// Returns ErrInvalidToken when the token is invalid or expired.
|
||||
func (s *AuthService) VerifyPendingOAuthToken(tokenStr string) (email, username string, err error) {
|
||||
if len(tokenStr) > maxTokenLength {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
parser := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
|
||||
token, parseErr := parser.ParseWithClaims(tokenStr, &pendingOAuthClaims{}, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return []byte(s.cfg.JWT.Secret), nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
claims, ok := token.Claims.(*pendingOAuthClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
if claims.Purpose != pendingOAuthPurpose {
|
||||
return "", "", ErrInvalidToken
|
||||
}
|
||||
return claims.Email, claims.Username, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
|
||||
if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
|
||||
return
|
||||
@@ -752,6 +710,95 @@ func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) postAuthUserBootstrap(ctx context.Context, user *User, signupSource string, touchLogin bool) {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
signupSource = "email"
|
||||
}
|
||||
s.updateUserSignupSource(ctx, user.ID, signupSource)
|
||||
|
||||
if signupSource == "email" {
|
||||
s.ensureEmailAuthIdentity(ctx, user)
|
||||
}
|
||||
if touchLogin {
|
||||
s.touchUserLogin(ctx, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) updateUserSignupSource(ctx context.Context, userID int64, signupSource string) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(signupSource) == "" {
|
||||
return
|
||||
}
|
||||
if err := s.entClient.User.UpdateOneID(userID).
|
||||
SetSignupSource(signupSource).
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to update signup source: user_id=%d source=%s err=%v", userID, signupSource, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) touchUserLogin(ctx context.Context, userID int64) {
|
||||
if s == nil || s.entClient == nil || userID <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := s.entClient.User.UpdateOneID(userID).
|
||||
SetLastLoginAt(now).
|
||||
SetLastActiveAt(now).
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to touch login timestamps: user_id=%d err=%v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User) {
|
||||
if s == nil || s.entClient == nil || user == nil || user.ID <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if email == "" || isReservedEmail(email) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.entClient.AuthIdentity.Create().
|
||||
SetUserID(user.ID).
|
||||
SetProviderType("email").
|
||||
SetProviderKey("email").
|
||||
SetProviderSubject(email).
|
||||
SetVerifiedAt(time.Now().UTC()).
|
||||
SetMetadata(map[string]any{
|
||||
"source": "auth_service_dual_write",
|
||||
}).
|
||||
OnConflictColumns(
|
||||
authidentity.FieldProviderType,
|
||||
authidentity.FieldProviderKey,
|
||||
authidentity.FieldProviderSubject,
|
||||
).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
logger.LegacyPrintf("service.auth", "[Auth] Failed to ensure email auth identity: user_id=%d email=%s err=%v", user.ID, email, err)
|
||||
}
|
||||
}
|
||||
|
||||
func inferLegacySignupSource(email string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
switch {
|
||||
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
|
||||
return "linuxdo"
|
||||
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
|
||||
return "oidc"
|
||||
case strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain):
|
||||
return "wechat"
|
||||
default:
|
||||
return "email"
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) validateRegistrationEmailPolicy(ctx context.Context, email string) error {
|
||||
if s.settingService == nil {
|
||||
return nil
|
||||
@@ -834,7 +881,8 @@ func randomHexString(byteLength int) (string, error) {
|
||||
func isReservedEmail(email string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(email))
|
||||
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain)
|
||||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT access token
|
||||
|
||||
153
backend/internal/service/auth_service_identity_sync_test.go
Normal file
153
backend/internal/service/auth_service_identity_sync_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
//go:build unit
|
||||
|
||||
package service_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/authidentity"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/repository"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type authIdentitySettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) Get(context.Context, string) (*service.Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", service.ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) Set(context.Context, string, string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *authIdentitySettingRepoStub) Delete(context.Context, string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func newAuthServiceWithEnt(t *testing.T) (*service.AuthService, service.UserRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:auth_service_identity_sync?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
repo := repository.NewUserRepository(client, db)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-auth-identity-secret",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 3.5,
|
||||
UserConcurrency: 2,
|
||||
},
|
||||
}
|
||||
settingSvc := service.NewSettingService(&authIdentitySettingRepoStub{
|
||||
values: map[string]string{
|
||||
service.SettingKeyRegistrationEnabled: "true",
|
||||
},
|
||||
}, cfg)
|
||||
|
||||
svc := service.NewAuthService(client, repo, nil, nil, cfg, settingSvc, nil, nil, nil, nil, nil)
|
||||
return svc, repo, client
|
||||
}
|
||||
|
||||
func TestAuthServiceRegisterDualWritesEmailIdentity(t *testing.T) {
|
||||
svc, _, client := newAuthServiceWithEnt(t)
|
||||
ctx := context.Background()
|
||||
|
||||
token, user, err := svc.Register(ctx, "user@example.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, user)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "email", storedUser.SignupSource)
|
||||
require.NotNil(t, storedUser.LastLoginAt)
|
||||
require.NotNil(t, storedUser.LastActiveAt)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().
|
||||
Where(
|
||||
authidentity.ProviderTypeEQ("email"),
|
||||
authidentity.ProviderKeyEQ("email"),
|
||||
authidentity.ProviderSubjectEQ("user@example.com"),
|
||||
).
|
||||
Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.ID, identity.UserID)
|
||||
require.NotNil(t, identity.VerifiedAt)
|
||||
}
|
||||
|
||||
func TestAuthServiceLoginTouchesLastLoginAt(t *testing.T) {
|
||||
svc, repo, client := newAuthServiceWithEnt(t)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &service.User{
|
||||
Email: "login@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
Concurrency: 1,
|
||||
}
|
||||
require.NoError(t, user.SetPassword("password"))
|
||||
require.NoError(t, repo.Create(ctx, user))
|
||||
|
||||
old := time.Now().Add(-2 * time.Hour).UTC().Round(time.Second)
|
||||
_, err := client.User.UpdateOneID(user.ID).
|
||||
SetLastLoginAt(old).
|
||||
SetLastActiveAt(old).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, gotUser, err := svc.Login(ctx, user.Email, "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, gotUser)
|
||||
|
||||
storedUser, err := client.User.Get(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, storedUser.LastLoginAt)
|
||||
require.NotNil(t, storedUser.LastActiveAt)
|
||||
require.True(t, storedUser.LastLoginAt.After(old))
|
||||
require.True(t, storedUser.LastActiveAt.After(old))
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthServiceForPendingOAuthTest() *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-pending-oauth",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
|
||||
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
email, username, err := svc.VerifyPendingOAuthToken(token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user@example.com", email)
|
||||
require.Equal(t, "alice", username)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
|
||||
accessToken, err := svc.GenerateToken(&User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: RoleUser,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "some_other_purpose",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "", // 旧 token 无此字段,反序列化后为零值
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
past := time.Now().Add(-1 * time.Hour)
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(past),
|
||||
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
|
||||
other := NewAuthService(nil, nil, nil, nil, &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "other-secret"},
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
_, _, err = svc.VerifyPendingOAuthToken(token)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
giant := make([]byte, maxTokenLength+1)
|
||||
for i := range giant {
|
||||
giant[i] = 'a'
|
||||
}
|
||||
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
@@ -74,6 +74,9 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// OIDCConnectSyntheticEmailDomain 是 OIDC 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
|
||||
|
||||
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
|
||||
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
@@ -153,6 +156,29 @@ const (
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
|
||||
|
||||
// 第三方认证来源默认授予配置
|
||||
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
|
||||
SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
|
||||
SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
|
||||
SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
|
||||
SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
|
||||
SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
|
||||
SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
|
||||
SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
|
||||
SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
|
||||
@@ -13,14 +13,30 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAccountScheduleLayerPreviousResponse = "previous_response_id"
|
||||
openAIAccountScheduleLayerSessionSticky = "session_hash"
|
||||
openAIAccountScheduleLayerLoadBalance = "load_balance"
|
||||
openAIAdvancedSchedulerSettingKey = "openai_advanced_scheduler_enabled"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAdvancedSchedulerSettingCacheTTL = 5 * time.Second
|
||||
openAIAdvancedSchedulerSettingDBTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type cachedOpenAIAdvancedSchedulerSetting struct {
|
||||
enabled bool
|
||||
expiresAt int64
|
||||
}
|
||||
|
||||
var openAIAdvancedSchedulerSettingCache atomic.Value // *cachedOpenAIAdvancedSchedulerSetting
|
||||
var openAIAdvancedSchedulerSettingSF singleflight.Group
|
||||
|
||||
type OpenAIAccountScheduleRequest struct {
|
||||
GroupID *int64
|
||||
SessionHash string
|
||||
@@ -805,10 +821,56 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler {
|
||||
func (s *OpenAIGatewayService) openAIAdvancedSchedulerSettingRepo() SettingRepository {
|
||||
if s == nil || s.rateLimitService == nil || s.rateLimitService.settingService == nil {
|
||||
return nil
|
||||
}
|
||||
return s.rateLimitService.settingService.settingRepo
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAIAdvancedSchedulerEnabled(ctx context.Context) bool {
|
||||
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.enabled
|
||||
}
|
||||
}
|
||||
|
||||
result, _, _ := openAIAdvancedSchedulerSettingSF.Do(openAIAdvancedSchedulerSettingKey, func() (any, error) {
|
||||
if cached, ok := openAIAdvancedSchedulerSettingCache.Load().(*cachedOpenAIAdvancedSchedulerSetting); ok && cached != nil {
|
||||
if time.Now().UnixNano() < cached.expiresAt {
|
||||
return cached.enabled, nil
|
||||
}
|
||||
}
|
||||
|
||||
enabled := false
|
||||
if repo := s.openAIAdvancedSchedulerSettingRepo(); repo != nil {
|
||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), openAIAdvancedSchedulerSettingDBTimeout)
|
||||
defer cancel()
|
||||
|
||||
value, err := repo.GetValue(dbCtx, openAIAdvancedSchedulerSettingKey)
|
||||
if err == nil {
|
||||
enabled = strings.EqualFold(strings.TrimSpace(value), "true")
|
||||
}
|
||||
}
|
||||
|
||||
openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{
|
||||
enabled: enabled,
|
||||
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
|
||||
})
|
||||
return enabled, nil
|
||||
})
|
||||
|
||||
enabled, _ := result.(bool)
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getOpenAIAccountScheduler(ctx context.Context) OpenAIAccountScheduler {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
if !s.isOpenAIAdvancedSchedulerEnabled(ctx) {
|
||||
return nil
|
||||
}
|
||||
s.openaiSchedulerOnce.Do(func() {
|
||||
if s.openaiAccountStats == nil {
|
||||
s.openaiAccountStats = newOpenAIAccountRuntimeStats()
|
||||
@@ -820,6 +882,11 @@ func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountSchedule
|
||||
return s.openaiScheduler
|
||||
}
|
||||
|
||||
func resetOpenAIAdvancedSchedulerSettingCacheForTest() {
|
||||
openAIAdvancedSchedulerSettingCache = atomic.Value{}
|
||||
openAIAdvancedSchedulerSettingSF = singleflight.Group{}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
ctx context.Context,
|
||||
groupID *int64,
|
||||
@@ -830,7 +897,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
requiredTransport OpenAIUpstreamTransport,
|
||||
) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) {
|
||||
decision := OpenAIAccountScheduleDecision{}
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
scheduler := s.getOpenAIAccountScheduler(ctx)
|
||||
if scheduler == nil {
|
||||
selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
decision.Layer = openAIAccountScheduleLayerLoadBalance
|
||||
@@ -856,7 +923,7 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
scheduler := s.getOpenAIAccountScheduler(context.Background())
|
||||
if scheduler == nil {
|
||||
return
|
||||
}
|
||||
@@ -864,7 +931,7 @@ func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
scheduler := s.getOpenAIAccountScheduler(context.Background())
|
||||
if scheduler == nil {
|
||||
return
|
||||
}
|
||||
@@ -872,7 +939,7 @@ func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot {
|
||||
scheduler := s.getOpenAIAccountScheduler()
|
||||
scheduler := s.getOpenAIAccountScheduler(context.Background())
|
||||
if scheduler == nil {
|
||||
return OpenAIAccountSchedulerMetricsSnapshot{}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
@@ -18,6 +19,202 @@ type openAISnapshotCacheStub struct {
|
||||
accountsByID map[int64]*Account
|
||||
}
|
||||
|
||||
type schedulerTestOpenAIAccountRepo struct {
|
||||
AccountRepository
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r schedulerTestOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
for i := range r.accounts {
|
||||
if r.accounts[i].ID == id {
|
||||
return &r.accounts[i], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r schedulerTestOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r schedulerTestOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return r.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
|
||||
type schedulerTestConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
loadBatchErr error
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
acquireResults map[int64]bool
|
||||
waitCounts map[int64]int
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
func (c schedulerTestConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c schedulerTestConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c schedulerTestConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if c.loadBatchErr != nil {
|
||||
return nil, c.loadBatchErr
|
||||
}
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
if c.skipDefaultLoad && c.loadMap != nil {
|
||||
for _, acc := range accounts {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
for _, acc := range accounts {
|
||||
if c.loadMap != nil {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
continue
|
||||
}
|
||||
}
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c schedulerTestConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
return count, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type schedulerTestGatewayCache struct {
|
||||
sessionBindings map[string]int64
|
||||
deletedSessions map[string]int
|
||||
}
|
||||
|
||||
func (c *schedulerTestGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := c.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (c *schedulerTestGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if c.sessionBindings == nil {
|
||||
c.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
c.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *schedulerTestGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *schedulerTestGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if c.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if c.deletedSessions == nil {
|
||||
c.deletedSessions = make(map[string]int)
|
||||
}
|
||||
c.deletedSessions[sessionHash]++
|
||||
delete(c.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newSchedulerTestOpenAIWSV2Config() *config.Config {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
|
||||
return cfg
|
||||
}
|
||||
|
||||
type openAIAdvancedSchedulerSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
value, err := s.GetValue(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Setting{Key: key, Value: value}, nil
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
if s == nil || s.values == nil {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
value, ok := s.values[key]
|
||||
if !ok {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) Set(context.Context, string, string) error {
|
||||
panic("unexpected call to Set")
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
|
||||
panic("unexpected call to GetMultiple")
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
panic("unexpected call to SetMultiple")
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
panic("unexpected call to GetAll")
|
||||
}
|
||||
|
||||
func (s *openAIAdvancedSchedulerSettingRepoStub) Delete(context.Context, string) error {
|
||||
panic("unexpected call to Delete")
|
||||
}
|
||||
|
||||
func newOpenAIAdvancedSchedulerRateLimitService(enabled string) *RateLimitService {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
repo := &openAIAdvancedSchedulerSettingRepoStub{
|
||||
values: map[string]string{},
|
||||
}
|
||||
if enabled != "" {
|
||||
repo.values[openAIAdvancedSchedulerSettingKey] = enabled
|
||||
}
|
||||
return &RateLimitService{
|
||||
settingService: NewSettingService(repo, &config.Config{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *openAISnapshotCacheStub) GetSnapshot(ctx context.Context, bucket SchedulerBucket) ([]*Account, bool, error) {
|
||||
if len(s.snapshotAccounts) == 0 {
|
||||
return nil, false, nil
|
||||
@@ -45,6 +242,138 @@ func (s *openAISnapshotCacheStub) GetAccount(ctx context.Context, accountID int6
|
||||
return &cloned, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_DefaultDisabledUsesLegacyLoadAwareness(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
ctx := context.Background()
|
||||
groupID := int64(10106)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 36001,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 5,
|
||||
},
|
||||
{
|
||||
ID: 36002,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
cache := &schedulerTestGatewayCache{}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
store := svc.getOpenAIWSStateStore()
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_disabled_001", 36001, time.Hour))
|
||||
require.False(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"resp_disabled_001",
|
||||
"",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(36002), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
|
||||
require.False(t, decision.StickyPreviousHit)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_EnabledUsesAdvancedPreviousResponseRouting(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
ctx := context.Background()
|
||||
groupID := int64(10107)
|
||||
accounts := []Account{
|
||||
{
|
||||
ID: 37001,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 5,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 37002,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
Priority: 0,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: cfg,
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
store := svc.getOpenAIWSStateStore()
|
||||
require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_enabled_001", 37001, time.Hour))
|
||||
require.True(t, svc.isOpenAIAdvancedSchedulerEnabled(ctx))
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
&groupID,
|
||||
"resp_enabled_001",
|
||||
"",
|
||||
"gpt-5.1",
|
||||
nil,
|
||||
OpenAIUpstreamTransportAny,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, selection)
|
||||
require.NotNil(t, selection.Account)
|
||||
require.Equal(t, int64(37001), selection.Account.ID)
|
||||
require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer)
|
||||
require.True(t, decision.StickyPreviousHit)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics_DisabledNoOp(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
ttft := 120
|
||||
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
|
||||
svc.RecordOpenAIAccountSwitch()
|
||||
|
||||
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
|
||||
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimitedAccountFallsBackToFreshCandidate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10101)
|
||||
@@ -53,10 +382,17 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyRateLimite
|
||||
staleBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
freshSticky := &Account{ID: 31001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
freshBackup := &Account{ID: 31002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
|
||||
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_rate_limited": 31001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{staleSticky, staleBackup}, accountsByID: map[int64]*Account{31001: freshSticky, 31002: freshBackup}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}}, cache: cache, cfg: &config.Config{}, schedulerSnapshot: snapshotService, concurrencyService: NewConcurrencyService(stubConcurrencyCache{})}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshSticky, *freshBackup}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
schedulerSnapshot: snapshotService,
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_rate_limited", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
require.NoError(t, err)
|
||||
@@ -76,7 +412,12 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_SkipsFreshlyRa
|
||||
freshSecondary := &Account{ID: 32002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
snapshotCache := &openAISnapshotCacheStub{snapshotAccounts: []*Account{stalePrimary, staleSecondary}, accountsByID: map[int64]*Account{32001: freshPrimary, 32002: freshSecondary}}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{accountRepo: stubOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}}, cfg: &config.Config{}, schedulerSnapshot: snapshotService}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*freshPrimary, *freshSecondary}},
|
||||
cfg: &config.Config{},
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
schedulerSnapshot: snapshotService,
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gpt-5.1", nil)
|
||||
require.NoError(t, err)
|
||||
@@ -92,18 +433,19 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyDBRuntimeR
|
||||
staleBackup := &Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
dbSticky := Account{ID: 33001, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &rateLimitedUntil}
|
||||
dbBackup := Account{ID: 33002, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 5}
|
||||
cache := &stubGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
|
||||
cache := &schedulerTestGatewayCache{sessionBindings: map[string]int64{"openai:session_hash_db_runtime_recheck": 33001}}
|
||||
snapshotCache := &openAISnapshotCacheStub{
|
||||
snapshotAccounts: []*Account{staleSticky, staleBackup},
|
||||
accountsByID: map[int64]*Account{33001: staleSticky, 33002: staleBackup},
|
||||
}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbSticky, dbBackup}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
schedulerSnapshot: snapshotService,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_db_runtime_recheck", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
@@ -128,8 +470,9 @@ func TestOpenAIGatewayService_SelectAccountForModelWithExclusions_DBRuntimeReche
|
||||
}
|
||||
snapshotService := &SchedulerSnapshotService{cache: snapshotCache}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{dbPrimary, dbSecondary}},
|
||||
cfg: &config.Config{},
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
schedulerSnapshot: snapshotService,
|
||||
}
|
||||
|
||||
@@ -153,7 +496,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
cache := &schedulerTestGatewayCache{}
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.OpenAIWS.Enabled = true
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
@@ -163,10 +506,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(
|
||||
cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
store := svc.getOpenAIWSStateStore()
|
||||
@@ -204,17 +548,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testin
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
cache := &schedulerTestGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_abc": account.ID,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
@@ -260,7 +605,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
|
||||
Priority: 9,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
cache := &schedulerTestGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_sticky_busy": 21001,
|
||||
},
|
||||
@@ -273,7 +618,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
|
||||
cfg.Gateway.OpenAIWS.OAuthEnabled = true
|
||||
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
concurrencyCache := schedulerTestConcurrencyCache{
|
||||
acquireResults: map[int64]bool{
|
||||
21001: false, // sticky 账号已满
|
||||
21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换)
|
||||
@@ -288,9 +633,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsS
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
@@ -328,17 +674,18 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP
|
||||
"openai_ws_force_http": true,
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
cache := &schedulerTestGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_force_http": account.ID,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
@@ -387,15 +734,15 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
cache := &schedulerTestGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_ws_only": 2201,
|
||||
},
|
||||
}
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg := newSchedulerTestOpenAIWSV2Config()
|
||||
|
||||
// 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
concurrencyCache := schedulerTestConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0},
|
||||
2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5},
|
||||
@@ -403,9 +750,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStick
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
@@ -445,10 +793,11 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailabl
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &stubGatewayCache{},
|
||||
cfg: newOpenAIWSV2TestConfig(),
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: newSchedulerTestOpenAIWSV2Config(),
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
@@ -507,7 +856,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
concurrencyCache := schedulerTestConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8},
|
||||
3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1},
|
||||
@@ -520,9 +869,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &stubGatewayCache{},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: cfg,
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
@@ -559,16 +909,17 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
|
||||
Schedulable: true,
|
||||
Concurrency: 1,
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
cache := &schedulerTestGatewayCache{
|
||||
sessionBindings: map[string]int64{
|
||||
"openai:session_hash_metrics": account.ID,
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{account}},
|
||||
cache: cache,
|
||||
cfg: &config.Config{},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
|
||||
@@ -749,7 +1100,7 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
|
||||
cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
|
||||
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
concurrencyCache := schedulerTestConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1},
|
||||
5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1},
|
||||
@@ -757,9 +1108,10 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesA
|
||||
},
|
||||
}
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: accounts},
|
||||
cache: &schedulerTestGatewayCache{sessionBindings: map[string]int64{}},
|
||||
cfg: cfg,
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
@@ -905,12 +1257,14 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
|
||||
resetOpenAIAdvancedSchedulerSettingCacheForTest()
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
ttft := 120
|
||||
svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
|
||||
svc.RecordOpenAIAccountSwitch()
|
||||
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
|
||||
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
|
||||
require.Equal(t, OpenAIAccountSchedulerMetricsSnapshot{}, snapshot)
|
||||
require.Equal(t, 7, svc.openAIWSLBTopK())
|
||||
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
|
||||
|
||||
@@ -947,7 +1301,7 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
|
||||
require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE))
|
||||
require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2))
|
||||
|
||||
cfg := newOpenAIWSV2TestConfig()
|
||||
cfg := newSchedulerTestOpenAIWSV2Config()
|
||||
scheduler.service = &OpenAIGatewayService{cfg: cfg}
|
||||
account := &Account{
|
||||
ID: 8801,
|
||||
|
||||
@@ -38,11 +38,12 @@ func TestOpenAIGatewayService_SelectAccountWithScheduler_UsesWSPassthroughSnapsh
|
||||
cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: stubOpenAIAccountRepo{accounts: []Account{*account}},
|
||||
cache: &stubGatewayCache{},
|
||||
accountRepo: schedulerTestOpenAIAccountRepo{accounts: []Account{*account}},
|
||||
cache: &schedulerTestGatewayCache{},
|
||||
cfg: cfg,
|
||||
rateLimitService: newOpenAIAdvancedSchedulerRateLimitService("true"),
|
||||
schedulerSnapshot: &SchedulerSnapshotService{cache: snapshotCache},
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
concurrencyService: NewConcurrencyService(schedulerTestConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, decision, err := svc.SelectAccountWithScheduler(
|
||||
|
||||
@@ -196,12 +196,25 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
|
||||
SettingHelpImageURL, SettingHelpText,
|
||||
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
|
||||
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
|
||||
SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
|
||||
SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
|
||||
}
|
||||
vals, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get payment config settings: %w", err)
|
||||
}
|
||||
cfg := s.parsePaymentConfig(vals)
|
||||
if s.entClient != nil {
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(paymentproviderinstance.EnabledEQ(true)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list enabled provider instances: %w", err)
|
||||
}
|
||||
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, buildVisibleMethodSourceAvailability(instances))
|
||||
} else {
|
||||
cfg.EnabledTypes = applyVisibleMethodRoutingToEnabledTypes(cfg.EnabledTypes, vals, nil)
|
||||
}
|
||||
// Load Stripe publishable key from the first enabled Stripe provider instance
|
||||
cfg.StripePublishableKey = s.getStripePublishableKey(ctx)
|
||||
return cfg, nil
|
||||
@@ -234,18 +247,23 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
|
||||
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
|
||||
}
|
||||
if raw := vals[SettingEnabledPaymentTypes]; raw != "" {
|
||||
types := make([]string, 0, len(strings.Split(raw, ",")))
|
||||
for _, t := range strings.Split(raw, ",") {
|
||||
t = strings.TrimSpace(t)
|
||||
if t != "" {
|
||||
cfg.EnabledTypes = append(cfg.EnabledTypes, t)
|
||||
types = append(types, t)
|
||||
}
|
||||
}
|
||||
cfg.EnabledTypes = NormalizeVisibleMethods(types)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// getStripePublishableKey finds the publishable key from the first enabled Stripe provider instance.
|
||||
func (s *PaymentConfigService) getStripePublishableKey(ctx context.Context) string {
|
||||
if s.entClient == nil {
|
||||
return ""
|
||||
}
|
||||
instances, err := s.entClient.PaymentProviderInstance.Query().
|
||||
Where(
|
||||
paymentproviderinstance.EnabledEQ(true),
|
||||
@@ -385,3 +403,79 @@ func pcParseInt(s string, defaultVal int) int {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func buildVisibleMethodSourceAvailability(instances []*dbent.PaymentProviderInstance) map[string]bool {
|
||||
available := make(map[string]bool, 4)
|
||||
for _, inst := range instances {
|
||||
switch inst.ProviderKey {
|
||||
case payment.TypeAlipay:
|
||||
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeAlipayDirect) {
|
||||
available[VisibleMethodSourceOfficialAlipay] = true
|
||||
}
|
||||
case payment.TypeWxpay:
|
||||
if inst.SupportedTypes == "" || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpay) || payment.InstanceSupportsType(inst.SupportedTypes, payment.TypeWxpayDirect) {
|
||||
available[VisibleMethodSourceOfficialWechat] = true
|
||||
}
|
||||
case payment.TypeEasyPay:
|
||||
for _, supportedType := range splitTypes(inst.SupportedTypes) {
|
||||
switch NormalizeVisibleMethod(supportedType) {
|
||||
case payment.TypeAlipay:
|
||||
available[VisibleMethodSourceEasyPayAlipay] = true
|
||||
case payment.TypeWxpay:
|
||||
available[VisibleMethodSourceEasyPayWechat] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return available
|
||||
}
|
||||
|
||||
func applyVisibleMethodRoutingToEnabledTypes(base []string, vals map[string]string, available map[string]bool) []string {
|
||||
shouldExpose := map[string]bool{
|
||||
payment.TypeAlipay: visibleMethodShouldBeExposed(payment.TypeAlipay, vals, available),
|
||||
payment.TypeWxpay: visibleMethodShouldBeExposed(payment.TypeWxpay, vals, available),
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(base)+2)
|
||||
out := make([]string, 0, len(base)+2)
|
||||
appendType := func(paymentType string) {
|
||||
paymentType = NormalizeVisibleMethod(paymentType)
|
||||
if paymentType == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[paymentType]; ok {
|
||||
return
|
||||
}
|
||||
seen[paymentType] = struct{}{}
|
||||
out = append(out, paymentType)
|
||||
}
|
||||
|
||||
for _, paymentType := range base {
|
||||
visibleMethod := NormalizeVisibleMethod(paymentType)
|
||||
switch visibleMethod {
|
||||
case payment.TypeAlipay, payment.TypeWxpay:
|
||||
if shouldExpose[visibleMethod] {
|
||||
appendType(visibleMethod)
|
||||
}
|
||||
default:
|
||||
appendType(visibleMethod)
|
||||
}
|
||||
}
|
||||
|
||||
for _, visibleMethod := range []string{payment.TypeAlipay, payment.TypeWxpay} {
|
||||
if shouldExpose[visibleMethod] {
|
||||
appendType(visibleMethod)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func visibleMethodShouldBeExposed(method string, vals map[string]string, available map[string]bool) bool {
|
||||
enabledKey := visibleMethodEnabledSettingKey(method)
|
||||
sourceKey := visibleMethodSourceSettingKey(method)
|
||||
if enabledKey == "" || sourceKey == "" || vals[enabledKey] != "true" {
|
||||
return false
|
||||
}
|
||||
source := NormalizeVisibleMethodSource(method, vals[sourceKey])
|
||||
return source != "" && available[source]
|
||||
}
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestPcParseFloat(t *testing.T) {
|
||||
@@ -163,6 +171,20 @@ func TestParsePaymentConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("enabled types are normalized to visible methods and deduplicated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
vals := map[string]string{
|
||||
SettingEnabledPaymentTypes: "alipay_direct, alipay, wxpay_direct, wxpay",
|
||||
}
|
||||
cfg := svc.parsePaymentConfig(vals)
|
||||
if len(cfg.EnabledTypes) != 2 {
|
||||
t.Fatalf("EnabledTypes len = %d, want 2", len(cfg.EnabledTypes))
|
||||
}
|
||||
if cfg.EnabledTypes[0] != "alipay" || cfg.EnabledTypes[1] != "wxpay" {
|
||||
t.Fatalf("EnabledTypes = %v, want [alipay wxpay]", cfg.EnabledTypes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty enabled types string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
vals := map[string]string{
|
||||
@@ -204,3 +226,167 @@ func TestGetBasePaymentType(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyVisibleMethodRoutingToEnabledTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := []string{"alipay", "wxpay", "stripe"}
|
||||
vals := map[string]string{
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "true",
|
||||
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "true",
|
||||
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
|
||||
}
|
||||
available := map[string]bool{
|
||||
VisibleMethodSourceOfficialAlipay: true,
|
||||
VisibleMethodSourceOfficialWechat: false,
|
||||
}
|
||||
|
||||
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
|
||||
want := []string{"alipay", "stripe"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyVisibleMethodRoutingAddsConfiguredVisibleMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
base := []string{"stripe"}
|
||||
vals := map[string]string{
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "true",
|
||||
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceEasyPayAlipay,
|
||||
}
|
||||
available := map[string]bool{
|
||||
VisibleMethodSourceEasyPayAlipay: true,
|
||||
}
|
||||
|
||||
got := applyVisibleMethodRoutingToEnabledTypes(base, vals, available)
|
||||
want := []string{"stripe", "alipay"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("applyVisibleMethodRoutingToEnabledTypes[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildVisibleMethodSourceAvailability(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
instances := []*dbent.PaymentProviderInstance{
|
||||
{ProviderKey: payment.TypeAlipay, SupportedTypes: "alipay"},
|
||||
{ProviderKey: payment.TypeEasyPay, SupportedTypes: "wxpay_direct, alipay"},
|
||||
{ProviderKey: payment.TypeWxpay, SupportedTypes: "wxpay_direct"},
|
||||
}
|
||||
|
||||
got := buildVisibleMethodSourceAvailability(instances)
|
||||
if !got[VisibleMethodSourceOfficialAlipay] {
|
||||
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialAlipay)
|
||||
}
|
||||
if !got[VisibleMethodSourceEasyPayAlipay] {
|
||||
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayAlipay)
|
||||
}
|
||||
if !got[VisibleMethodSourceOfficialWechat] {
|
||||
t.Fatalf("expected %q to be available", VisibleMethodSourceOfficialWechat)
|
||||
}
|
||||
if !got[VisibleMethodSourceEasyPayWechat] {
|
||||
t.Fatalf("expected %q to be available", VisibleMethodSourceEasyPayWechat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPaymentConfigAppliesVisibleMethodRouting(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := newPaymentConfigServiceTestClient(t)
|
||||
|
||||
_, err := client.PaymentProviderInstance.Create().
|
||||
SetProviderKey(payment.TypeEasyPay).
|
||||
SetName("EasyPay Alipay").
|
||||
SetConfig("{}").
|
||||
SetSupportedTypes("alipay").
|
||||
SetEnabled(true).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("create easypay instance: %v", err)
|
||||
}
|
||||
|
||||
svc := &PaymentConfigService{
|
||||
entClient: client,
|
||||
settingRepo: &paymentConfigSettingRepoStub{
|
||||
values: map[string]string{
|
||||
SettingEnabledPaymentTypes: "alipay,wxpay,stripe",
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "true",
|
||||
SettingPaymentVisibleMethodAlipaySource: "easypay",
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "true",
|
||||
SettingPaymentVisibleMethodWxpaySource: "wxpay",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg, err := svc.GetPaymentConfig(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPaymentConfig returned error: %v", err)
|
||||
}
|
||||
|
||||
want := []string{payment.TypeAlipay, payment.TypeStripe}
|
||||
if len(cfg.EnabledTypes) != len(want) {
|
||||
t.Fatalf("EnabledTypes len = %d, want %d (%v)", len(cfg.EnabledTypes), len(want), cfg.EnabledTypes)
|
||||
}
|
||||
for i := range want {
|
||||
if cfg.EnabledTypes[i] != want[i] {
|
||||
t.Fatalf("EnabledTypes[%d] = %q, want %q (full=%v)", i, cfg.EnabledTypes[i], want[i], cfg.EnabledTypes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newPaymentConfigServiceTestClient(t *testing.T) *dbent.Client {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:payment_config_service?mode=memory&cache=shared")
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
t.Fatalf("enable foreign keys: %v", err)
|
||||
}
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
return client
|
||||
}
|
||||
|
||||
type paymentConfigSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *paymentConfigSettingRepoStub) Get(context.Context, string) (*Setting, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (s *paymentConfigSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
return s.values[key], nil
|
||||
}
|
||||
func (s *paymentConfigSettingRepoStub) Set(context.Context, string, string) error { return nil }
|
||||
func (s *paymentConfigSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
out[key] = s.values[key]
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *paymentConfigSettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
func (s *paymentConfigSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
return s.values, nil
|
||||
}
|
||||
func (s *paymentConfigSettingRepoStub) Delete(context.Context, string) error { return nil }
|
||||
|
||||
248
backend/internal/service/payment_resume_service.go
Normal file
248
backend/internal/service/payment_resume_service.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
PaymentSourceHostedRedirect = "hosted_redirect"
|
||||
PaymentSourceWechatInAppResume = "wechat_in_app_resume"
|
||||
|
||||
paymentResumeFallbackSigningKey = "sub2api-payment-resume"
|
||||
|
||||
SettingPaymentVisibleMethodAlipaySource = "payment_visible_method_alipay_source"
|
||||
SettingPaymentVisibleMethodWxpaySource = "payment_visible_method_wxpay_source"
|
||||
SettingPaymentVisibleMethodAlipayEnabled = "payment_visible_method_alipay_enabled"
|
||||
SettingPaymentVisibleMethodWxpayEnabled = "payment_visible_method_wxpay_enabled"
|
||||
|
||||
VisibleMethodSourceOfficialAlipay = "official_alipay"
|
||||
VisibleMethodSourceEasyPayAlipay = "easypay_alipay"
|
||||
VisibleMethodSourceOfficialWechat = "official_wxpay"
|
||||
VisibleMethodSourceEasyPayWechat = "easypay_wxpay"
|
||||
)
|
||||
|
||||
type ResumeTokenClaims struct {
|
||||
OrderID int64 `json:"oid"`
|
||||
UserID int64 `json:"uid,omitempty"`
|
||||
ProviderInstanceID string `json:"pi,omitempty"`
|
||||
ProviderKey string `json:"pk,omitempty"`
|
||||
PaymentType string `json:"pt,omitempty"`
|
||||
CanonicalReturnURL string `json:"ru,omitempty"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
}
|
||||
|
||||
type PaymentResumeService struct {
|
||||
signingKey []byte
|
||||
}
|
||||
|
||||
type visibleMethodLoadBalancer struct {
|
||||
inner payment.LoadBalancer
|
||||
configService *PaymentConfigService
|
||||
}
|
||||
|
||||
func NewPaymentResumeService(signingKey []byte) *PaymentResumeService {
|
||||
return &PaymentResumeService{signingKey: signingKey}
|
||||
}
|
||||
|
||||
func NormalizeVisibleMethod(method string) string {
|
||||
return payment.GetBasePaymentType(strings.TrimSpace(method))
|
||||
}
|
||||
|
||||
func NormalizeVisibleMethods(methods []string) []string {
|
||||
if len(methods) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(methods))
|
||||
out := make([]string, 0, len(methods))
|
||||
for _, method := range methods {
|
||||
normalized := NormalizeVisibleMethod(method)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
out = append(out, normalized)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func NormalizePaymentSource(source string) string {
|
||||
switch strings.TrimSpace(strings.ToLower(source)) {
|
||||
case "", PaymentSourceHostedRedirect:
|
||||
return PaymentSourceHostedRedirect
|
||||
case "wechat_in_app", "wxpay_resume", PaymentSourceWechatInAppResume:
|
||||
return PaymentSourceWechatInAppResume
|
||||
default:
|
||||
return strings.TrimSpace(strings.ToLower(source))
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeVisibleMethodSource(method, source string) string {
|
||||
switch NormalizeVisibleMethod(method) {
|
||||
case payment.TypeAlipay:
|
||||
switch strings.TrimSpace(strings.ToLower(source)) {
|
||||
case VisibleMethodSourceOfficialAlipay, payment.TypeAlipay, payment.TypeAlipayDirect, "official":
|
||||
return VisibleMethodSourceOfficialAlipay
|
||||
case VisibleMethodSourceEasyPayAlipay, payment.TypeEasyPay:
|
||||
return VisibleMethodSourceEasyPayAlipay
|
||||
}
|
||||
case payment.TypeWxpay:
|
||||
switch strings.TrimSpace(strings.ToLower(source)) {
|
||||
case VisibleMethodSourceOfficialWechat, payment.TypeWxpay, payment.TypeWxpayDirect, "wechat", "official":
|
||||
return VisibleMethodSourceOfficialWechat
|
||||
case VisibleMethodSourceEasyPayWechat, payment.TypeEasyPay:
|
||||
return VisibleMethodSourceEasyPayWechat
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func VisibleMethodProviderKeyForSource(method, source string) (string, bool) {
|
||||
switch NormalizeVisibleMethodSource(method, source) {
|
||||
case VisibleMethodSourceOfficialAlipay:
|
||||
return payment.TypeAlipay, NormalizeVisibleMethod(method) == payment.TypeAlipay
|
||||
case VisibleMethodSourceEasyPayAlipay:
|
||||
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeAlipay
|
||||
case VisibleMethodSourceOfficialWechat:
|
||||
return payment.TypeWxpay, NormalizeVisibleMethod(method) == payment.TypeWxpay
|
||||
case VisibleMethodSourceEasyPayWechat:
|
||||
return payment.TypeEasyPay, NormalizeVisibleMethod(method) == payment.TypeWxpay
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func newVisibleMethodLoadBalancer(inner payment.LoadBalancer, configService *PaymentConfigService) payment.LoadBalancer {
|
||||
if inner == nil || configService == nil || configService.settingRepo == nil {
|
||||
return inner
|
||||
}
|
||||
return &visibleMethodLoadBalancer{inner: inner, configService: configService}
|
||||
}
|
||||
|
||||
func (lb *visibleMethodLoadBalancer) GetInstanceConfig(ctx context.Context, instanceID int64) (map[string]string, error) {
|
||||
return lb.inner.GetInstanceConfig(ctx, instanceID)
|
||||
}
|
||||
|
||||
func (lb *visibleMethodLoadBalancer) SelectInstance(ctx context.Context, providerKey string, paymentType payment.PaymentType, strategy payment.Strategy, orderAmount float64) (*payment.InstanceSelection, error) {
|
||||
visibleMethod := NormalizeVisibleMethod(paymentType)
|
||||
if providerKey != "" || (visibleMethod != payment.TypeAlipay && visibleMethod != payment.TypeWxpay) {
|
||||
return lb.inner.SelectInstance(ctx, providerKey, paymentType, strategy, orderAmount)
|
||||
}
|
||||
|
||||
enabledKey := visibleMethodEnabledSettingKey(visibleMethod)
|
||||
sourceKey := visibleMethodSourceSettingKey(visibleMethod)
|
||||
vals, err := lb.configService.settingRepo.GetMultiple(ctx, []string{enabledKey, sourceKey})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load visible method routing for %s: %w", visibleMethod, err)
|
||||
}
|
||||
if vals[enabledKey] != "true" {
|
||||
return nil, fmt.Errorf("visible payment method %s is disabled", visibleMethod)
|
||||
}
|
||||
|
||||
targetProviderKey, ok := VisibleMethodProviderKeyForSource(visibleMethod, vals[sourceKey])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("visible payment method %s has no valid source", visibleMethod)
|
||||
}
|
||||
return lb.inner.SelectInstance(ctx, targetProviderKey, paymentType, strategy, orderAmount)
|
||||
}
|
||||
|
||||
func visibleMethodEnabledSettingKey(method string) string {
|
||||
switch NormalizeVisibleMethod(method) {
|
||||
case payment.TypeAlipay:
|
||||
return SettingPaymentVisibleMethodAlipayEnabled
|
||||
case payment.TypeWxpay:
|
||||
return SettingPaymentVisibleMethodWxpayEnabled
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func visibleMethodSourceSettingKey(method string) string {
|
||||
switch NormalizeVisibleMethod(method) {
|
||||
case payment.TypeAlipay:
|
||||
return SettingPaymentVisibleMethodAlipaySource
|
||||
case payment.TypeWxpay:
|
||||
return SettingPaymentVisibleMethodWxpaySource
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func CanonicalizeReturnURL(raw string) (string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", nil
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil || !parsed.IsAbs() || parsed.Host == "" {
|
||||
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must be an absolute http/https URL")
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", infraerrors.BadRequest("INVALID_RETURN_URL", "return_url must use http or https")
|
||||
}
|
||||
parsed.Fragment = ""
|
||||
if parsed.Path == "" {
|
||||
parsed.Path = "/"
|
||||
}
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) CreateToken(claims ResumeTokenClaims) (string, error) {
|
||||
if claims.OrderID <= 0 {
|
||||
return "", fmt.Errorf("resume token requires order id")
|
||||
}
|
||||
if claims.IssuedAt == 0 {
|
||||
claims.IssuedAt = time.Now().Unix()
|
||||
}
|
||||
payload, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal resume claims: %w", err)
|
||||
}
|
||||
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
|
||||
return encodedPayload + "." + s.sign(encodedPayload), nil
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) ParseToken(token string) (*ResumeTokenClaims, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token is malformed")
|
||||
}
|
||||
if !hmac.Equal([]byte(parts[1]), []byte(s.sign(parts[0]))) {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token signature mismatch")
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is malformed")
|
||||
}
|
||||
var claims ResumeTokenClaims
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token payload is invalid")
|
||||
}
|
||||
if claims.OrderID <= 0 {
|
||||
return nil, infraerrors.BadRequest("INVALID_RESUME_TOKEN", "resume token missing order id")
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
func (s *PaymentResumeService) sign(payload string) string {
|
||||
key := s.signingKey
|
||||
if len(key) == 0 {
|
||||
key = []byte(paymentResumeFallbackSigningKey)
|
||||
}
|
||||
mac := hmac.New(sha256.New, key)
|
||||
_, _ = mac.Write([]byte(payload))
|
||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
240
backend/internal/service/payment_resume_service_test.go
Normal file
240
backend/internal/service/payment_resume_service_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/payment"
|
||||
)
|
||||
|
||||
func TestNormalizeVisibleMethods(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := NormalizeVisibleMethods([]string{
|
||||
"alipay_direct",
|
||||
"alipay",
|
||||
" wxpay_direct ",
|
||||
"wxpay",
|
||||
"stripe",
|
||||
})
|
||||
|
||||
want := []string{"alipay", "wxpay", "stripe"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("NormalizeVisibleMethods len = %d, want %d (%v)", len(got), len(want), got)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("NormalizeVisibleMethods[%d] = %q, want %q (full=%v)", i, got[i], want[i], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePaymentSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{name: "empty uses default", input: "", expect: PaymentSourceHostedRedirect},
|
||||
{name: "wechat alias normalized", input: "wechat_in_app", expect: PaymentSourceWechatInAppResume},
|
||||
{name: "canonical value preserved", input: PaymentSourceWechatInAppResume, expect: PaymentSourceWechatInAppResume},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := NormalizePaymentSource(tt.input); got != tt.expect {
|
||||
t.Fatalf("NormalizePaymentSource(%q) = %q, want %q", tt.input, got, tt.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalizeReturnURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := CanonicalizeReturnURL("https://example.com/pay/result?b=2#a")
|
||||
if err != nil {
|
||||
t.Fatalf("CanonicalizeReturnURL returned error: %v", err)
|
||||
}
|
||||
if got != "https://example.com/pay/result?b=2" {
|
||||
t.Fatalf("CanonicalizeReturnURL = %q, want %q", got, "https://example.com/pay/result?b=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanonicalizeReturnURLRejectsRelativeURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if _, err := CanonicalizeReturnURL("/payment/result"); err == nil {
|
||||
t.Fatal("CanonicalizeReturnURL should reject relative URLs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaymentResumeTokenRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
svc := NewPaymentResumeService([]byte("0123456789abcdef0123456789abcdef"))
|
||||
token, err := svc.CreateToken(ResumeTokenClaims{
|
||||
OrderID: 42,
|
||||
UserID: 7,
|
||||
ProviderInstanceID: "19",
|
||||
ProviderKey: "easypay",
|
||||
PaymentType: "wxpay",
|
||||
CanonicalReturnURL: "https://example.com/payment/result",
|
||||
IssuedAt: 1234567890,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateToken returned error: %v", err)
|
||||
}
|
||||
|
||||
claims, err := svc.ParseToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseToken returned error: %v", err)
|
||||
}
|
||||
if claims.OrderID != 42 || claims.UserID != 7 {
|
||||
t.Fatalf("claims mismatch: %+v", claims)
|
||||
}
|
||||
if claims.ProviderInstanceID != "19" || claims.ProviderKey != "easypay" || claims.PaymentType != "wxpay" {
|
||||
t.Fatalf("claims provider snapshot mismatch: %+v", claims)
|
||||
}
|
||||
if claims.CanonicalReturnURL != "https://example.com/payment/result" {
|
||||
t.Fatalf("claims return URL = %q", claims.CanonicalReturnURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeVisibleMethodSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{name: "alipay official alias", method: payment.TypeAlipay, input: "alipay", want: VisibleMethodSourceOfficialAlipay},
|
||||
{name: "alipay easypay alias", method: payment.TypeAlipay, input: "easypay", want: VisibleMethodSourceEasyPayAlipay},
|
||||
{name: "wxpay official alias", method: payment.TypeWxpay, input: "wxpay", want: VisibleMethodSourceOfficialWechat},
|
||||
{name: "wxpay easypay alias", method: payment.TypeWxpay, input: "easypay", want: VisibleMethodSourceEasyPayWechat},
|
||||
{name: "unsupported source", method: payment.TypeWxpay, input: "stripe", want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := NormalizeVisibleMethodSource(tt.method, tt.input); got != tt.want {
|
||||
t.Fatalf("NormalizeVisibleMethodSource(%q, %q) = %q, want %q", tt.method, tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisibleMethodProviderKeyForSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
source string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{name: "official alipay", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialAlipay, want: payment.TypeAlipay, ok: true},
|
||||
{name: "easypay alipay", method: payment.TypeAlipay, source: VisibleMethodSourceEasyPayAlipay, want: payment.TypeEasyPay, ok: true},
|
||||
{name: "official wechat", method: payment.TypeWxpay, source: VisibleMethodSourceOfficialWechat, want: payment.TypeWxpay, ok: true},
|
||||
{name: "easypay wechat", method: payment.TypeWxpay, source: VisibleMethodSourceEasyPayWechat, want: payment.TypeEasyPay, ok: true},
|
||||
{name: "mismatched method and source", method: payment.TypeAlipay, source: VisibleMethodSourceOfficialWechat, want: "", ok: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, ok := VisibleMethodProviderKeyForSource(tt.method, tt.source)
|
||||
if got != tt.want || ok != tt.ok {
|
||||
t.Fatalf("VisibleMethodProviderKeyForSource(%q, %q) = (%q, %v), want (%q, %v)", tt.method, tt.source, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisibleMethodLoadBalancerUsesConfiguredSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &captureLoadBalancer{}
|
||||
configService := &PaymentConfigService{
|
||||
settingRepo: &paymentSettingRepoStub{
|
||||
values: map[string]string{
|
||||
SettingPaymentVisibleMethodAlipayEnabled: "true",
|
||||
SettingPaymentVisibleMethodAlipaySource: VisibleMethodSourceOfficialAlipay,
|
||||
},
|
||||
},
|
||||
}
|
||||
lb := newVisibleMethodLoadBalancer(inner, configService)
|
||||
|
||||
_, err := lb.SelectInstance(context.Background(), "", payment.TypeAlipay, payment.StrategyRoundRobin, 12.5)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectInstance returned error: %v", err)
|
||||
}
|
||||
if inner.lastProviderKey != payment.TypeAlipay {
|
||||
t.Fatalf("lastProviderKey = %q, want %q", inner.lastProviderKey, payment.TypeAlipay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisibleMethodLoadBalancerRejectsDisabledVisibleMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &captureLoadBalancer{}
|
||||
configService := &PaymentConfigService{
|
||||
settingRepo: &paymentSettingRepoStub{
|
||||
values: map[string]string{
|
||||
SettingPaymentVisibleMethodWxpayEnabled: "false",
|
||||
SettingPaymentVisibleMethodWxpaySource: VisibleMethodSourceOfficialWechat,
|
||||
},
|
||||
},
|
||||
}
|
||||
lb := newVisibleMethodLoadBalancer(inner, configService)
|
||||
|
||||
if _, err := lb.SelectInstance(context.Background(), "", payment.TypeWxpay, payment.StrategyRoundRobin, 9.9); err == nil {
|
||||
t.Fatal("SelectInstance should reject disabled visible method")
|
||||
}
|
||||
}
|
||||
|
||||
type paymentSettingRepoStub struct {
|
||||
values map[string]string
|
||||
}
|
||||
|
||||
func (s *paymentSettingRepoStub) Get(context.Context, string) (*Setting, error) { return nil, nil }
|
||||
func (s *paymentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
return s.values[key], nil
|
||||
}
|
||||
func (s *paymentSettingRepoStub) Set(context.Context, string, string) error { return nil }
|
||||
func (s *paymentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
out[key] = s.values[key]
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
func (s *paymentSettingRepoStub) SetMultiple(context.Context, map[string]string) error { return nil }
|
||||
func (s *paymentSettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
return s.values, nil
|
||||
}
|
||||
func (s *paymentSettingRepoStub) Delete(context.Context, string) error { return nil }
|
||||
|
||||
type captureLoadBalancer struct {
|
||||
lastProviderKey string
|
||||
lastPaymentType string
|
||||
}
|
||||
|
||||
func (c *captureLoadBalancer) GetInstanceConfig(context.Context, int64) (map[string]string, error) {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
|
||||
func (c *captureLoadBalancer) SelectInstance(_ context.Context, providerKey string, paymentType payment.PaymentType, _ payment.Strategy, _ float64) (*payment.InstanceSelection, error) {
|
||||
c.lastProviderKey = providerKey
|
||||
c.lastPaymentType = paymentType
|
||||
return &payment.InstanceSelection{ProviderKey: providerKey, SupportedTypes: paymentType}, nil
|
||||
}
|
||||
@@ -65,15 +65,17 @@ func generateRandomString(n int) string {
|
||||
}
|
||||
|
||||
type CreateOrderRequest struct {
|
||||
UserID int64
|
||||
Amount float64
|
||||
PaymentType string
|
||||
ClientIP string
|
||||
IsMobile bool
|
||||
SrcHost string
|
||||
SrcURL string
|
||||
OrderType string
|
||||
PlanID int64
|
||||
UserID int64
|
||||
Amount float64
|
||||
PaymentType string
|
||||
ClientIP string
|
||||
IsMobile bool
|
||||
SrcHost string
|
||||
SrcURL string
|
||||
ReturnURL string
|
||||
PaymentSource string
|
||||
OrderType string
|
||||
PlanID int64
|
||||
}
|
||||
|
||||
type CreateOrderResponse struct {
|
||||
@@ -88,6 +90,7 @@ type CreateOrderResponse struct {
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
PaymentMode string `json:"payment_mode,omitempty"`
|
||||
ResumeToken string `json:"resume_token,omitempty"`
|
||||
}
|
||||
|
||||
type OrderListParams struct {
|
||||
@@ -165,10 +168,13 @@ type PaymentService struct {
|
||||
configService *PaymentConfigService
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
resumeService *PaymentResumeService
|
||||
}
|
||||
|
||||
func NewPaymentService(entClient *dbent.Client, registry *payment.Registry, loadBalancer payment.LoadBalancer, redeemService *RedeemService, subscriptionSvc *SubscriptionService, configService *PaymentConfigService, userRepo UserRepository, groupRepo GroupRepository) *PaymentService {
|
||||
return &PaymentService{entClient: entClient, registry: registry, loadBalancer: loadBalancer, redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
||||
svc := &PaymentService{entClient: entClient, registry: registry, loadBalancer: newVisibleMethodLoadBalancer(loadBalancer, configService), redeemService: redeemService, subscriptionSvc: subscriptionSvc, configService: configService, userRepo: userRepo, groupRepo: groupRepo}
|
||||
svc.resumeService = NewPaymentResumeService(psResumeSigningKey(configService))
|
||||
return svc
|
||||
}
|
||||
|
||||
// --- Provider Registry ---
|
||||
@@ -262,6 +268,20 @@ func psNilIfEmpty(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s *PaymentService) paymentResume() *PaymentResumeService {
|
||||
if s.resumeService != nil {
|
||||
return s.resumeService
|
||||
}
|
||||
return NewPaymentResumeService(psResumeSigningKey(s.configService))
|
||||
}
|
||||
|
||||
func psResumeSigningKey(configService *PaymentConfigService) []byte {
|
||||
if configService == nil {
|
||||
return nil
|
||||
}
|
||||
return configService.encryptionKey
|
||||
}
|
||||
|
||||
func psSliceContains(sl []string, s string) bool {
|
||||
for _, v := range sl {
|
||||
if v == s {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -114,6 +115,66 @@ type SettingService struct {
|
||||
webSearchManagerBuilder WebSearchManagerBuilder
|
||||
}
|
||||
|
||||
type ProviderDefaultGrantSettings struct {
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Subscriptions []DefaultSubscriptionSetting
|
||||
GrantOnSignup bool
|
||||
GrantOnFirstBind bool
|
||||
}
|
||||
|
||||
type AuthSourceDefaultSettings struct {
|
||||
Email ProviderDefaultGrantSettings
|
||||
LinuxDo ProviderDefaultGrantSettings
|
||||
OIDC ProviderDefaultGrantSettings
|
||||
WeChat ProviderDefaultGrantSettings
|
||||
ForceEmailOnThirdPartySignup bool
|
||||
}
|
||||
|
||||
type authSourceDefaultKeySet struct {
|
||||
balance string
|
||||
concurrency string
|
||||
subscriptions string
|
||||
grantOnSignup string
|
||||
grantOnFirstBind string
|
||||
}
|
||||
|
||||
var (
|
||||
emailAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultEmailBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultEmailConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultEmailSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultEmailGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
|
||||
}
|
||||
linuxDoAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultLinuxDoBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultLinuxDoConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
|
||||
}
|
||||
oidcAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultOIDCBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultOIDCConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultOIDCSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
|
||||
}
|
||||
weChatAuthSourceDefaultKeys = authSourceDefaultKeySet{
|
||||
balance: SettingKeyAuthSourceDefaultWeChatBalance,
|
||||
concurrency: SettingKeyAuthSourceDefaultWeChatConcurrency,
|
||||
subscriptions: SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
||||
grantOnSignup: SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
grantOnFirstBind: SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAuthSourceBalance = 0
|
||||
defaultAuthSourceConcurrency = 5
|
||||
)
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *SettingService {
|
||||
return &SettingService{
|
||||
@@ -212,6 +273,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
if oidcProviderName == "" {
|
||||
oidcProviderName = "OIDC"
|
||||
}
|
||||
weChatEnabled := isWeChatOAuthConfigured()
|
||||
|
||||
// Password reset requires email verification to be enabled
|
||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||
@@ -254,6 +316,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
CustomMenuItems: settings[SettingKeyCustomMenuItems],
|
||||
CustomEndpoints: settings[SettingKeyCustomEndpoints],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
WeChatOAuthEnabled: weChatEnabled,
|
||||
BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true",
|
||||
PaymentEnabled: settings[SettingPaymentEnabled] == "true",
|
||||
OIDCOAuthEnabled: oidcEnabled,
|
||||
@@ -310,6 +373,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
|
||||
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||
PaymentEnabled bool `json:"payment_enabled"`
|
||||
OIDCOAuthEnabled bool `json:"oidc_oauth_enabled"`
|
||||
@@ -344,6 +408,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
|
||||
BackendModeEnabled: settings.BackendModeEnabled,
|
||||
PaymentEnabled: settings.PaymentEnabled,
|
||||
OIDCOAuthEnabled: settings.OIDCOAuthEnabled,
|
||||
@@ -392,6 +457,14 @@ func filterUserVisibleMenuItems(raw string) json.RawMessage {
|
||||
return result
|
||||
}
|
||||
|
||||
func isWeChatOAuthConfigured() bool {
|
||||
openConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_ID")) != "" &&
|
||||
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_OPEN_APP_SECRET")) != ""
|
||||
mpConfigured := strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_ID")) != "" &&
|
||||
strings.TrimSpace(os.Getenv("WECHAT_OAUTH_MP_APP_SECRET")) != ""
|
||||
return openConfigured || mpConfigured
|
||||
}
|
||||
|
||||
// safeRawJSONArray returns raw as json.RawMessage if it's valid JSON, otherwise "[]".
|
||||
func safeRawJSONArray(raw string) json.RawMessage {
|
||||
raw = strings.TrimSpace(raw)
|
||||
@@ -919,6 +992,74 @@ func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultS
|
||||
return parseDefaultSubscriptions(value)
|
||||
}
|
||||
|
||||
func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*AuthSourceDefaultSettings, error) {
|
||||
keys := []string{
|
||||
SettingKeyAuthSourceDefaultEmailBalance,
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency,
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultLinuxDoBalance,
|
||||
SettingKeyAuthSourceDefaultLinuxDoConcurrency,
|
||||
SettingKeyAuthSourceDefaultLinuxDoSubscriptions,
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultOIDCBalance,
|
||||
SettingKeyAuthSourceDefaultOIDCConcurrency,
|
||||
SettingKeyAuthSourceDefaultOIDCSubscriptions,
|
||||
SettingKeyAuthSourceDefaultOIDCGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind,
|
||||
SettingKeyAuthSourceDefaultWeChatBalance,
|
||||
SettingKeyAuthSourceDefaultWeChatConcurrency,
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions,
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup,
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind,
|
||||
SettingKeyForceEmailOnThirdPartySignup,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get auth source default settings: %w", err)
|
||||
}
|
||||
|
||||
return &AuthSourceDefaultSettings{
|
||||
Email: parseProviderDefaultGrantSettings(settings, emailAuthSourceDefaultKeys),
|
||||
LinuxDo: parseProviderDefaultGrantSettings(settings, linuxDoAuthSourceDefaultKeys),
|
||||
OIDC: parseProviderDefaultGrantSettings(settings, oidcAuthSourceDefaultKeys),
|
||||
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
|
||||
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SettingService) UpdateAuthSourceDefaultSettings(ctx context.Context, settings *AuthSourceDefaultSettings) error {
|
||||
if settings == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, subscriptions := range [][]DefaultSubscriptionSetting{
|
||||
settings.Email.Subscriptions,
|
||||
settings.LinuxDo.Subscriptions,
|
||||
settings.OIDC.Subscriptions,
|
||||
settings.WeChat.Subscriptions,
|
||||
} {
|
||||
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
updates := make(map[string]string, 21)
|
||||
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
|
||||
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
|
||||
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
|
||||
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
|
||||
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
|
||||
|
||||
if err := s.settingRepo.SetMultiple(ctx, updates); err != nil {
|
||||
return fmt.Errorf("update auth source default settings: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeDefaultSettings 初始化默认设置
|
||||
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
// 检查是否已有设置
|
||||
@@ -933,25 +1074,46 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
|
||||
// 初始化默认设置
|
||||
defaults := map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
SettingKeyPurchaseSubscriptionURL: "",
|
||||
SettingKeyTableDefaultPageSize: "20",
|
||||
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
|
||||
SettingKeyCustomMenuItems: "[]",
|
||||
SettingKeyCustomEndpoints: "[]",
|
||||
SettingKeyOIDCConnectEnabled: "false",
|
||||
SettingKeyOIDCConnectProviderName: "OIDC",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: "[]",
|
||||
SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
SettingKeyPurchaseSubscriptionURL: "",
|
||||
SettingKeyTableDefaultPageSize: "20",
|
||||
SettingKeyTablePageSizeOptions: "[10,20,50,100]",
|
||||
SettingKeyCustomMenuItems: "[]",
|
||||
SettingKeyCustomEndpoints: "[]",
|
||||
SettingKeyOIDCConnectEnabled: "false",
|
||||
SettingKeyOIDCConnectProviderName: "OIDC",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeyDefaultSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "0",
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "true",
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
|
||||
SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "true",
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultOIDCBalance: "0",
|
||||
SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "true",
|
||||
SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
|
||||
SettingKeyAuthSourceDefaultWeChatBalance: "0",
|
||||
SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
|
||||
SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "true",
|
||||
SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
|
||||
SettingKeyForceEmailOnThirdPartySignup: "false",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
// Model fallback defaults
|
||||
SettingKeyEnableModelFallback: "false",
|
||||
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
|
||||
@@ -1164,6 +1326,8 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
} else {
|
||||
result.OIDCConnectValidateIDToken = oidcBase.ValidateIDToken
|
||||
}
|
||||
result.OIDCConnectUsePKCE = true
|
||||
result.OIDCConnectValidateIDToken = true
|
||||
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
||||
result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
|
||||
} else {
|
||||
@@ -1317,6 +1481,51 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
|
||||
return normalized
|
||||
}
|
||||
|
||||
func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings {
|
||||
result := ProviderDefaultGrantSettings{
|
||||
Balance: defaultAuthSourceBalance,
|
||||
Concurrency: defaultAuthSourceConcurrency,
|
||||
Subscriptions: []DefaultSubscriptionSetting{},
|
||||
GrantOnSignup: true,
|
||||
GrantOnFirstBind: false,
|
||||
}
|
||||
|
||||
if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil {
|
||||
result.Balance = v
|
||||
}
|
||||
if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil {
|
||||
result.Concurrency = v
|
||||
}
|
||||
if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil {
|
||||
result.Subscriptions = items
|
||||
}
|
||||
if raw, ok := settings[keys.grantOnSignup]; ok {
|
||||
result.GrantOnSignup = raw == "true"
|
||||
}
|
||||
if raw, ok := settings[keys.grantOnFirstBind]; ok {
|
||||
result.GrantOnFirstBind = raw == "true"
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) {
|
||||
updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64)
|
||||
updates[keys.concurrency] = strconv.Itoa(settings.Concurrency)
|
||||
|
||||
subscriptions := settings.Subscriptions
|
||||
if subscriptions == nil {
|
||||
subscriptions = []DefaultSubscriptionSetting{}
|
||||
}
|
||||
raw, err := json.Marshal(subscriptions)
|
||||
if err != nil {
|
||||
raw = []byte("[]")
|
||||
}
|
||||
updates[keys.subscriptions] = string(raw)
|
||||
updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup)
|
||||
updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind)
|
||||
}
|
||||
|
||||
func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) {
|
||||
defaultPageSize := 20
|
||||
if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil {
|
||||
@@ -1539,6 +1748,7 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.RedirectURL = strings.TrimSpace(v)
|
||||
}
|
||||
effective.UsePKCE = true
|
||||
|
||||
if !effective.Enabled {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled")
|
||||
@@ -1587,9 +1797,6 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
case "none":
|
||||
if !effective.UsePKCE {
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
|
||||
}
|
||||
default:
|
||||
return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
|
||||
}
|
||||
@@ -1737,6 +1944,8 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
|
||||
if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
|
||||
effective.ValidateIDToken = raw == "true"
|
||||
}
|
||||
effective.UsePKCE = true
|
||||
effective.ValidateIDToken = true
|
||||
if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
|
||||
effective.AllowedSigningAlgs = strings.TrimSpace(v)
|
||||
}
|
||||
@@ -1864,9 +2073,6 @@ func (s *SettingService) GetOIDCConnectOAuthConfig(ctx context.Context) (config.
|
||||
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured")
|
||||
}
|
||||
case "none":
|
||||
if !effective.UsePKCE {
|
||||
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none")
|
||||
}
|
||||
default:
|
||||
return config.OIDCConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type authSourceDefaultsRepoStub struct {
|
||||
values map[string]string
|
||||
updates map[string]string
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
panic("unexpected GetValue call")
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
out := make(map[string]string, len(keys))
|
||||
for _, key := range keys {
|
||||
if value, ok := s.values[key]; ok {
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
s.updates = make(map[string]string, len(settings))
|
||||
for key, value := range settings {
|
||||
s.updates[key] = value
|
||||
if s.values == nil {
|
||||
s.values = map[string]string{}
|
||||
}
|
||||
s.values[key] = value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *authSourceDefaultsRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func TestSettingService_GetAuthSourceDefaultSettings_ParsesValuesAndDefaults(t *testing.T) {
|
||||
repo := &authSourceDefaultsRepoStub{
|
||||
values: map[string]string{
|
||||
SettingKeyAuthSourceDefaultEmailBalance: "12.5",
|
||||
SettingKeyAuthSourceDefaultEmailConcurrency: "7",
|
||||
SettingKeyAuthSourceDefaultEmailSubscriptions: `[{"group_id":11,"validity_days":30}]`,
|
||||
SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
|
||||
SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "true",
|
||||
SettingKeyForceEmailOnThirdPartySignup: "true",
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
got, err := svc.GetAuthSourceDefaultSettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 12.5, got.Email.Balance)
|
||||
require.Equal(t, 7, got.Email.Concurrency)
|
||||
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 11, ValidityDays: 30}}, got.Email.Subscriptions)
|
||||
require.False(t, got.Email.GrantOnSignup)
|
||||
require.False(t, got.Email.GrantOnFirstBind)
|
||||
require.Equal(t, 0.0, got.LinuxDo.Balance)
|
||||
require.Equal(t, 5, got.LinuxDo.Concurrency)
|
||||
require.Equal(t, []DefaultSubscriptionSetting{}, got.LinuxDo.Subscriptions)
|
||||
require.True(t, got.LinuxDo.GrantOnSignup)
|
||||
require.True(t, got.LinuxDo.GrantOnFirstBind)
|
||||
require.Equal(t, 5, got.OIDC.Concurrency)
|
||||
require.Equal(t, 5, got.WeChat.Concurrency)
|
||||
require.True(t, got.ForceEmailOnThirdPartySignup)
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateAuthSourceDefaultSettings_PersistsAllKeys(t *testing.T) {
|
||||
repo := &authSourceDefaultsRepoStub{}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
err := svc.UpdateAuthSourceDefaultSettings(context.Background(), &AuthSourceDefaultSettings{
|
||||
Email: ProviderDefaultGrantSettings{
|
||||
Balance: 1.25,
|
||||
Concurrency: 3,
|
||||
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 21, ValidityDays: 14}},
|
||||
GrantOnSignup: false,
|
||||
GrantOnFirstBind: true,
|
||||
},
|
||||
LinuxDo: ProviderDefaultGrantSettings{
|
||||
Balance: 2,
|
||||
Concurrency: 4,
|
||||
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 22, ValidityDays: 30}},
|
||||
GrantOnSignup: true,
|
||||
GrantOnFirstBind: false,
|
||||
},
|
||||
OIDC: ProviderDefaultGrantSettings{
|
||||
Balance: 3,
|
||||
Concurrency: 5,
|
||||
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 23, ValidityDays: 60}},
|
||||
GrantOnSignup: true,
|
||||
GrantOnFirstBind: true,
|
||||
},
|
||||
WeChat: ProviderDefaultGrantSettings{
|
||||
Balance: 4,
|
||||
Concurrency: 6,
|
||||
Subscriptions: []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}},
|
||||
GrantOnSignup: false,
|
||||
GrantOnFirstBind: false,
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1.25000000", repo.updates[SettingKeyAuthSourceDefaultEmailBalance])
|
||||
require.Equal(t, "3", repo.updates[SettingKeyAuthSourceDefaultEmailConcurrency])
|
||||
require.Equal(t, "false", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnSignup])
|
||||
require.Equal(t, "true", repo.updates[SettingKeyAuthSourceDefaultEmailGrantOnFirstBind])
|
||||
require.Equal(t, "true", repo.updates[SettingKeyForceEmailOnThirdPartySignup])
|
||||
|
||||
var got []DefaultSubscriptionSetting
|
||||
require.NoError(t, json.Unmarshal([]byte(repo.updates[SettingKeyAuthSourceDefaultWeChatSubscriptions]), &got))
|
||||
require.Equal(t, []DefaultSubscriptionSetting{{GroupID: 24, ValidityDays: 90}}, got)
|
||||
}
|
||||
@@ -152,6 +152,7 @@ type PublicSettings struct {
|
||||
CustomEndpoints string // JSON array of custom endpoints
|
||||
|
||||
LinuxDoOAuthEnabled bool
|
||||
WeChatOAuthEnabled bool
|
||||
BackendModeEnabled bool
|
||||
PaymentEnabled bool
|
||||
OIDCOAuthEnabled bool
|
||||
|
||||
@@ -7,19 +7,27 @@ import (
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Email string
|
||||
Username string
|
||||
Notes string
|
||||
PasswordHash string
|
||||
Role string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Status string
|
||||
AllowedGroups []int64
|
||||
TokenVersion int64 // Incremented on password change to invalidate existing tokens
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
ID int64
|
||||
Email string
|
||||
Username string
|
||||
Notes string
|
||||
AvatarURL string
|
||||
AvatarSource string
|
||||
AvatarMIME string
|
||||
AvatarByteSize int
|
||||
AvatarSHA256 string
|
||||
PasswordHash string
|
||||
Role string
|
||||
Balance float64
|
||||
Concurrency int
|
||||
Status string
|
||||
AllowedGroups []int64
|
||||
TokenVersion int64 // Incremented on password change to invalidate existing tokens
|
||||
SignupSource string
|
||||
LastLoginAt *time.Time
|
||||
LastActiveAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
|
||||
@@ -2,9 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,10 +21,14 @@ var (
|
||||
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
|
||||
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
||||
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
|
||||
ErrAvatarInvalid = infraerrors.BadRequest("AVATAR_INVALID", "avatar must be a valid image data URL or http(s) URL")
|
||||
ErrAvatarTooLarge = infraerrors.BadRequest("AVATAR_TOO_LARGE", "avatar image must be 100KB or smaller")
|
||||
ErrAvatarNotImage = infraerrors.BadRequest("AVATAR_NOT_IMAGE", "avatar content must be an image")
|
||||
)
|
||||
|
||||
const (
|
||||
maxNotifyEmails = 3 // Maximum number of notification emails per user
|
||||
maxNotifyEmails = 3 // Maximum number of notification emails per user
|
||||
maxInlineAvatarBytes = 100 * 1024
|
||||
|
||||
// User-level rate limiting for notify email verification codes
|
||||
notifyCodeUserRateLimit = 5
|
||||
@@ -47,6 +55,9 @@ type UserRepository interface {
|
||||
GetFirstAdmin(ctx context.Context) (*User, error)
|
||||
Update(ctx context.Context, user *User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error)
|
||||
UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
|
||||
DeleteUserAvatar(ctx context.Context, userID int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
|
||||
@@ -71,11 +82,30 @@ type UserRepository interface {
|
||||
type UpdateProfileRequest struct {
|
||||
Email *string `json:"email"`
|
||||
Username *string `json:"username"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
|
||||
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
|
||||
}
|
||||
|
||||
type UserAvatar struct {
|
||||
StorageProvider string
|
||||
StorageKey string
|
||||
URL string
|
||||
ContentType string
|
||||
ByteSize int
|
||||
SHA256 string
|
||||
}
|
||||
|
||||
type UpsertUserAvatarInput struct {
|
||||
StorageProvider string
|
||||
StorageKey string
|
||||
URL string
|
||||
ContentType string
|
||||
ByteSize int
|
||||
SHA256 string
|
||||
}
|
||||
|
||||
// ChangePasswordRequest 修改密码请求
|
||||
type ChangePasswordRequest struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
@@ -115,6 +145,9 @@ func (s *UserService) GetProfile(ctx context.Context, userID int64) (*User, erro
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
if err := s.hydrateUserAvatar(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("get user avatar: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -143,6 +176,27 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
||||
user.Username = *req.Username
|
||||
}
|
||||
|
||||
if req.AvatarURL != nil {
|
||||
avatarValue := strings.TrimSpace(*req.AvatarURL)
|
||||
switch {
|
||||
case avatarValue == "":
|
||||
if err := s.userRepo.DeleteUserAvatar(ctx, userID); err != nil {
|
||||
return nil, fmt.Errorf("delete avatar: %w", err)
|
||||
}
|
||||
applyUserAvatar(user, nil)
|
||||
default:
|
||||
avatarInput, err := normalizeUserAvatarInput(avatarValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
avatar, err := s.userRepo.UpsertUserAvatar(ctx, userID, avatarInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upsert avatar: %w", err)
|
||||
}
|
||||
applyUserAvatar(user, avatar)
|
||||
}
|
||||
}
|
||||
|
||||
if req.Concurrency != nil {
|
||||
user.Concurrency = *req.Concurrency
|
||||
}
|
||||
@@ -168,6 +222,87 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func applyUserAvatar(user *User, avatar *UserAvatar) {
|
||||
if user == nil {
|
||||
return
|
||||
}
|
||||
if avatar == nil {
|
||||
user.AvatarURL = ""
|
||||
user.AvatarSource = ""
|
||||
user.AvatarMIME = ""
|
||||
user.AvatarByteSize = 0
|
||||
user.AvatarSHA256 = ""
|
||||
return
|
||||
}
|
||||
|
||||
user.AvatarURL = avatar.URL
|
||||
user.AvatarSource = avatar.StorageProvider
|
||||
user.AvatarMIME = avatar.ContentType
|
||||
user.AvatarByteSize = avatar.ByteSize
|
||||
user.AvatarSHA256 = avatar.SHA256
|
||||
}
|
||||
|
||||
func normalizeUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
if strings.HasPrefix(raw, "data:") {
|
||||
return normalizeInlineUserAvatarInput(raw)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil || parsed == nil {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "http") && !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
if strings.TrimSpace(parsed.Host) == "" {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
|
||||
return UpsertUserAvatarInput{
|
||||
StorageProvider: "remote_url",
|
||||
URL: raw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeInlineUserAvatarInput(raw string) (UpsertUserAvatarInput, error) {
|
||||
body := strings.TrimPrefix(raw, "data:")
|
||||
meta, encoded, ok := strings.Cut(body, ",")
|
||||
if !ok {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
meta = strings.TrimSpace(meta)
|
||||
encoded = strings.TrimSpace(encoded)
|
||||
if !strings.HasSuffix(strings.ToLower(meta), ";base64") {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
|
||||
contentType := strings.TrimSpace(meta[:len(meta)-len(";base64")])
|
||||
if contentType == "" || !strings.HasPrefix(strings.ToLower(contentType), "image/") {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarNotImage
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarInvalid
|
||||
}
|
||||
if len(decoded) > maxInlineAvatarBytes {
|
||||
return UpsertUserAvatarInput{}, ErrAvatarTooLarge
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(decoded)
|
||||
return UpsertUserAvatarInput{
|
||||
StorageProvider: "inline",
|
||||
URL: raw,
|
||||
ContentType: contentType,
|
||||
ByteSize: len(decoded),
|
||||
SHA256: hex.EncodeToString(sum[:]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID int64, req ChangePasswordRequest) error {
|
||||
@@ -202,9 +337,25 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
if err := s.hydrateUserAvatar(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("get user avatar: %w", err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) hydrateUserAvatar(ctx context.Context, user *User) error {
|
||||
if s == nil || s.userRepo == nil || user == nil || user.ID == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
avatar, err := s.userRepo.GetUserAvatar(ctx, user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
applyUserAvatar(user, avatar)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取用户列表(管理员功能)
|
||||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
users, pagination, err := s.userRepo.List(ctx, params)
|
||||
|
||||
@@ -4,6 +4,9 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -19,14 +22,65 @@ import (
|
||||
type mockUserRepo struct {
|
||||
updateBalanceErr error
|
||||
updateBalanceFn func(ctx context.Context, id int64, amount float64) error
|
||||
getByIDUser *User
|
||||
getByIDErr error
|
||||
updateFn func(ctx context.Context, user *User) error
|
||||
updateCalls int
|
||||
upsertAvatarFn func(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error)
|
||||
upsertAvatarArgs []UpsertUserAvatarInput
|
||||
deleteAvatarFn func(ctx context.Context, userID int64) error
|
||||
deleteAvatarIDs []int64
|
||||
getAvatarFn func(ctx context.Context, userID int64) (*UserAvatar, error)
|
||||
}
|
||||
|
||||
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) Create(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) {
|
||||
if m.getByIDErr != nil {
|
||||
return nil, m.getByIDErr
|
||||
}
|
||||
if m.getByIDUser != nil {
|
||||
cloned := *m.getByIDUser
|
||||
return &cloned, nil
|
||||
}
|
||||
return &User{}, nil
|
||||
}
|
||||
func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil }
|
||||
func (m *mockUserRepo) Update(context.Context, *User) error { return nil }
|
||||
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (m *mockUserRepo) Update(ctx context.Context, user *User) error {
|
||||
m.updateCalls++
|
||||
if m.updateFn != nil {
|
||||
return m.updateFn(ctx, user)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockUserRepo) Delete(context.Context, int64) error { return nil }
|
||||
func (m *mockUserRepo) GetUserAvatar(ctx context.Context, userID int64) (*UserAvatar, error) {
|
||||
if m.getAvatarFn != nil {
|
||||
return m.getAvatarFn(ctx, userID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockUserRepo) UpsertUserAvatar(ctx context.Context, userID int64, input UpsertUserAvatarInput) (*UserAvatar, error) {
|
||||
m.upsertAvatarArgs = append(m.upsertAvatarArgs, input)
|
||||
if m.upsertAvatarFn != nil {
|
||||
return m.upsertAvatarFn(ctx, userID, input)
|
||||
}
|
||||
return &UserAvatar{
|
||||
StorageProvider: input.StorageProvider,
|
||||
StorageKey: input.StorageKey,
|
||||
URL: input.URL,
|
||||
ContentType: input.ContentType,
|
||||
ByteSize: input.ByteSize,
|
||||
SHA256: input.SHA256,
|
||||
}, nil
|
||||
}
|
||||
func (m *mockUserRepo) DeleteUserAvatar(ctx context.Context, userID int64) error {
|
||||
m.deleteAvatarIDs = append(m.deleteAvatarIDs, userID)
|
||||
if m.deleteAvatarFn != nil {
|
||||
return m.deleteAvatarFn(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
@@ -200,3 +254,121 @@ func TestNewUserService_FieldsAssignment(t *testing.T) {
|
||||
require.Equal(t, auth, svc.authCacheInvalidator)
|
||||
require.Equal(t, cache, svc.billingCache)
|
||||
}
|
||||
|
||||
func TestUpdateProfile_StoresInlineAvatarWithinLimit(t *testing.T) {
|
||||
raw := []byte("small-avatar")
|
||||
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
|
||||
expectedSum := sha256.Sum256(raw)
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 7,
|
||||
Email: "avatar@example.com",
|
||||
Username: "avatar-user",
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
updated, err := svc.UpdateProfile(context.Background(), 7, UpdateProfileRequest{
|
||||
AvatarURL: &dataURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repo.upsertAvatarArgs, 1)
|
||||
require.Equal(t, "inline", repo.upsertAvatarArgs[0].StorageProvider)
|
||||
require.Equal(t, "image/png", repo.upsertAvatarArgs[0].ContentType)
|
||||
require.Equal(t, len(raw), repo.upsertAvatarArgs[0].ByteSize)
|
||||
require.Equal(t, hex.EncodeToString(expectedSum[:]), repo.upsertAvatarArgs[0].SHA256)
|
||||
require.Equal(t, dataURL, updated.AvatarURL)
|
||||
require.Equal(t, "inline", updated.AvatarSource)
|
||||
require.Equal(t, "image/png", updated.AvatarMIME)
|
||||
require.Equal(t, len(raw), updated.AvatarByteSize)
|
||||
require.Equal(t, hex.EncodeToString(expectedSum[:]), updated.AvatarSHA256)
|
||||
}
|
||||
|
||||
func TestUpdateProfile_RejectsInlineAvatarOverLimit(t *testing.T) {
|
||||
raw := make([]byte, maxInlineAvatarBytes+1)
|
||||
dataURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(raw)
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 8,
|
||||
Email: "large-avatar@example.com",
|
||||
Username: "too-large",
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
_, err := svc.UpdateProfile(context.Background(), 8, UpdateProfileRequest{
|
||||
AvatarURL: &dataURL,
|
||||
})
|
||||
require.ErrorIs(t, err, ErrAvatarTooLarge)
|
||||
require.Empty(t, repo.upsertAvatarArgs)
|
||||
require.Empty(t, repo.deleteAvatarIDs)
|
||||
require.Zero(t, repo.updateCalls)
|
||||
}
|
||||
|
||||
func TestUpdateProfile_StoresRemoteAvatarURL(t *testing.T) {
|
||||
remoteURL := "https://cdn.example.com/avatar.png"
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 9,
|
||||
Email: "remote-avatar@example.com",
|
||||
Username: "remote-avatar",
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
updated, err := svc.UpdateProfile(context.Background(), 9, UpdateProfileRequest{
|
||||
AvatarURL: &remoteURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, repo.upsertAvatarArgs, 1)
|
||||
require.Equal(t, "remote_url", repo.upsertAvatarArgs[0].StorageProvider)
|
||||
require.Equal(t, remoteURL, repo.upsertAvatarArgs[0].URL)
|
||||
require.Equal(t, remoteURL, updated.AvatarURL)
|
||||
require.Equal(t, "remote_url", updated.AvatarSource)
|
||||
require.Zero(t, updated.AvatarByteSize)
|
||||
}
|
||||
|
||||
func TestUpdateProfile_DeletesAvatarOnEmptyString(t *testing.T) {
|
||||
empty := ""
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 10,
|
||||
Email: "delete-avatar@example.com",
|
||||
Username: "delete-avatar",
|
||||
AvatarURL: "https://cdn.example.com/old.png",
|
||||
AvatarSource: "remote_url",
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
updated, err := svc.UpdateProfile(context.Background(), 10, UpdateProfileRequest{
|
||||
AvatarURL: &empty,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{10}, repo.deleteAvatarIDs)
|
||||
require.Empty(t, repo.upsertAvatarArgs)
|
||||
require.Empty(t, updated.AvatarURL)
|
||||
require.Empty(t, updated.AvatarSource)
|
||||
}
|
||||
|
||||
func TestGetProfile_HydratesAvatarFromRepository(t *testing.T) {
|
||||
repo := &mockUserRepo{
|
||||
getByIDUser: &User{
|
||||
ID: 12,
|
||||
Email: "profile-avatar@example.com",
|
||||
Username: "profile-avatar",
|
||||
},
|
||||
getAvatarFn: func(context.Context, int64) (*UserAvatar, error) {
|
||||
return &UserAvatar{
|
||||
StorageProvider: "remote_url",
|
||||
URL: "https://cdn.example.com/profile.png",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
svc := NewUserService(repo, nil, nil, nil)
|
||||
|
||||
user, err := svc.GetProfile(context.Background(), 12)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "https://cdn.example.com/profile.png", user.AvatarURL)
|
||||
require.Equal(t, "remote_url", user.AvatarSource)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user