diff --git a/README.md b/README.md index fa965e6f..e8e9c5a5 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ English | [中文](README_CN.md) ## Demo -Try Sub2API online: **https://v2.pincc.ai/** +Try Sub2API online: **https://demo.sub2api.org/** Demo credentials (shared demo environment; **not** created automatically for self-hosted installs): diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 5ef04a66..d9ff788e 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -70,6 +70,7 @@ func provideCleanup( schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, + subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, @@ -138,6 +139,10 @@ func provideCleanup( accountExpiry.Stop() return nil }}, + {"SubscriptionExpiryService", func() error { + subscriptionExpiry.Stop() + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 7b22a31e..71624091 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) + secretEncryptor, err := repository.NewAESEncryptor(configConfig) + if err != nil { + return nil, err + } + totpCache := repository.NewTotpCache(redisClient) + totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService) userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) @@ -165,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) + totpHandler := handler.NewTotpHandler(totpService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -178,7 +185,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -211,6 +219,7 @@ func provideCleanup( schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, + subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, @@ -278,6 +287,10 @@ func provideCleanup( accountExpiry.Stop() return nil }}, + {"SubscriptionExpiryService", func() error { + subscriptionExpiry.Stop() + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d1f05186..d2a39331 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -610,6 +610,9 @@ var ( {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "username", Type: field.TypeString, Size: 100, Default: ""}, {Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "totp_enabled", Type: field.TypeBool, Default: false}, + {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 9b330616..7f3071c2 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -14360,6 +14360,9 @@ type UserMutation struct { status *string username *string notes *string + totp_secret_encrypted *string + totp_enabled *bool + totp_enabled_at *time.Time clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -14937,6 +14940,140 @@ func (m *UserMutation) ResetNotes() { m.notes = nil } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (m *UserMutation) SetTotpSecretEncrypted(s string) { + m.totp_secret_encrypted = &s +} + +// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation. +func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) { + v := m.totp_secret_encrypted + if v == nil { + return + } + return *v, true +} + +// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err) + } + return oldValue.TotpSecretEncrypted, nil +} + +// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field. +func (m *UserMutation) ClearTotpSecretEncrypted() { + m.totp_secret_encrypted = nil + m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{} +} + +// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation. +func (m *UserMutation) TotpSecretEncryptedCleared() bool { + _, ok := m.clearedFields[user.FieldTotpSecretEncrypted] + return ok +} + +// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field. +func (m *UserMutation) ResetTotpSecretEncrypted() { + m.totp_secret_encrypted = nil + delete(m.clearedFields, user.FieldTotpSecretEncrypted) +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (m *UserMutation) SetTotpEnabled(b bool) { + m.totp_enabled = &b +} + +// TotpEnabled returns the value of the "totp_enabled" field in the mutation. +func (m *UserMutation) TotpEnabled() (r bool, exists bool) { + v := m.totp_enabled + if v == nil { + return + } + return *v, true +} + +// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err) + } + return oldValue.TotpEnabled, nil +} + +// ResetTotpEnabled resets all changes to the "totp_enabled" field. +func (m *UserMutation) ResetTotpEnabled() { + m.totp_enabled = nil +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (m *UserMutation) SetTotpEnabledAt(t time.Time) { + m.totp_enabled_at = &t +} + +// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation. +func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) { + v := m.totp_enabled_at + if v == nil { + return + } + return *v, true +} + +// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err) + } + return oldValue.TotpEnabledAt, nil +} + +// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field. +func (m *UserMutation) ClearTotpEnabledAt() { + m.totp_enabled_at = nil + m.clearedFields[user.FieldTotpEnabledAt] = struct{}{} +} + +// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation. +func (m *UserMutation) TotpEnabledAtCleared() bool { + _, ok := m.clearedFields[user.FieldTotpEnabledAt] + return ok +} + +// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field. +func (m *UserMutation) ResetTotpEnabledAt() { + m.totp_enabled_at = nil + delete(m.clearedFields, user.FieldTotpEnabledAt) +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -15403,7 +15540,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -15437,6 +15574,15 @@ func (m *UserMutation) Fields() []string { if m.notes != nil { fields = append(fields, user.FieldNotes) } + if m.totp_secret_encrypted != nil { + fields = append(fields, user.FieldTotpSecretEncrypted) + } + if m.totp_enabled != nil { + fields = append(fields, user.FieldTotpEnabled) + } + if m.totp_enabled_at != nil { + fields = append(fields, user.FieldTotpEnabledAt) + } return fields } @@ -15467,6 +15613,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.Username() case user.FieldNotes: return m.Notes() + case user.FieldTotpSecretEncrypted: + return m.TotpSecretEncrypted() + case user.FieldTotpEnabled: + return m.TotpEnabled() + case user.FieldTotpEnabledAt: + return m.TotpEnabledAt() } return nil, false } @@ -15498,6 +15650,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldUsername(ctx) case user.FieldNotes: return m.OldNotes(ctx) + case user.FieldTotpSecretEncrypted: + return m.OldTotpSecretEncrypted(ctx) + case user.FieldTotpEnabled: + return m.OldTotpEnabled(ctx) + case user.FieldTotpEnabledAt: + return m.OldTotpEnabledAt(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -15584,6 +15742,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetNotes(v) return nil + case user.FieldTotpSecretEncrypted: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpSecretEncrypted(v) + return nil + case user.FieldTotpEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpEnabled(v) + return nil + case user.FieldTotpEnabledAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpEnabledAt(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -15644,6 +15823,12 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldDeletedAt) { fields = append(fields, user.FieldDeletedAt) } + if m.FieldCleared(user.FieldTotpSecretEncrypted) { + fields = append(fields, user.FieldTotpSecretEncrypted) + } + if m.FieldCleared(user.FieldTotpEnabledAt) { + fields = append(fields, user.FieldTotpEnabledAt) + } return fields } @@ -15661,6 +15846,12 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldDeletedAt: m.ClearDeletedAt() return nil + case user.FieldTotpSecretEncrypted: + m.ClearTotpSecretEncrypted() + return nil + case user.FieldTotpEnabledAt: + m.ClearTotpEnabledAt() + return nil } return fmt.Errorf("unknown User nullable field %s", name) } @@ -15702,6 +15893,15 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldNotes: m.ResetNotes() return nil + case user.FieldTotpSecretEncrypted: + m.ResetTotpSecretEncrypted() + return nil + case user.FieldTotpEnabled: + m.ResetTotpEnabled() + return nil + case user.FieldTotpEnabledAt: + m.ResetTotpEnabledAt() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 1e3f4cbe..14323f8c 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -736,6 +736,10 @@ func init() { userDescNotes := userFields[7].Descriptor() // user.DefaultNotes holds the default value on creation for the notes field. user.DefaultNotes = userDescNotes.Default.(string) + // userDescTotpEnabled is the schema descriptor for totp_enabled field. + userDescTotpEnabled := userFields[9].Descriptor() + // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. + user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index 79dc2286..335c1cc8 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -61,6 +61,17 @@ func (User) Fields() []ent.Field { field.String("notes"). SchemaType(map[string]string{dialect.Postgres: "text"}). Default(""), + + // TOTP 双因素认证字段 + field.String("totp_secret_encrypted"). + SchemaType(map[string]string{dialect.Postgres: "text"}). + Optional(). + Nillable(), + field.Bool("totp_enabled"). + Default(false), + field.Time("totp_enabled_at"). + Optional(). + Nillable(), } } diff --git a/backend/ent/user.go b/backend/ent/user.go index 0b9a48cc..82830a95 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -39,6 +39,12 @@ type User struct { Username string `json:"username,omitempty"` // Notes holds the value of the "notes" field. Notes string `json:"notes,omitempty"` + // TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field. + TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"` + // TotpEnabled holds the value of the "totp_enabled" field. + TotpEnabled bool `json:"totp_enabled,omitempty"` + // TotpEnabledAt holds the value of the "totp_enabled_at" field. + TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case user.FieldTotpEnabled: + values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) - case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt: + case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Notes = value.String } + case user.FieldTotpSecretEncrypted: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i]) + } else if value.Valid { + _m.TotpSecretEncrypted = new(string) + *_m.TotpSecretEncrypted = value.String + } + case user.FieldTotpEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field totp_enabled", values[i]) + } else if value.Valid { + _m.TotpEnabled = value.Bool + } + case user.FieldTotpEnabledAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i]) + } else if value.Valid { + _m.TotpEnabledAt = new(time.Time) + *_m.TotpEnabledAt = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -367,6 +395,19 @@ func (_m *User) String() string { builder.WriteString(", ") builder.WriteString("notes=") builder.WriteString(_m.Notes) + builder.WriteString(", ") + if v := _m.TotpSecretEncrypted; v != nil { + builder.WriteString("totp_secret_encrypted=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("totp_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled)) + builder.WriteString(", ") + if v := _m.TotpEnabledAt; v != nil { + builder.WriteString("totp_enabled_at=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 1be1d871..0685ed72 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -37,6 +37,12 @@ const ( FieldUsername = "username" // FieldNotes holds the string denoting the notes field in the database. FieldNotes = "notes" + // FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database. + FieldTotpSecretEncrypted = "totp_secret_encrypted" + // FieldTotpEnabled holds the string denoting the totp_enabled field in the database. + FieldTotpEnabled = "totp_enabled" + // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. + FieldTotpEnabledAt = "totp_enabled_at" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -134,6 +140,9 @@ var Columns = []string{ FieldStatus, FieldUsername, FieldNotes, + FieldTotpSecretEncrypted, + FieldTotpEnabled, + FieldTotpEnabledAt, } var ( @@ -188,6 +197,8 @@ var ( UsernameValidator func(string) error // DefaultNotes holds the default value on creation for the "notes" field. DefaultNotes string + // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. + DefaultTotpEnabled bool ) // OrderOption defines the ordering options for the User queries. @@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldNotes, opts...).ToFunc() } +// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field. +func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc() +} + +// ByTotpEnabled orders the results by the totp_enabled field. +func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc() +} + +// ByTotpEnabledAt orders the results by the totp_enabled_at field. +func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 6a460f10..3dc4fec7 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -110,6 +110,21 @@ func Notes(v string) predicate.User { return predicate.User(sql.FieldEQ(FieldNotes, v)) } +// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ. +func TotpSecretEncrypted(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v)) +} + +// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ. +func TotpEnabled(v bool) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotpEnabled, v)) +} + +// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ. +func TotpEnabledAt(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User { return predicate.User(sql.FieldContainsFold(FieldNotes, v)) } +// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...)) +} + +// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...)) +} + +// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted)) +} + +// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted)) +} + +// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v)) +} + +// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field. +func TotpSecretEncryptedContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v)) +} + +// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field. +func TotpEnabledEQ(v bool) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotpEnabled, v)) +} + +// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field. +func TotpEnabledNEQ(v bool) predicate.User { + return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v)) +} + +// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field. +func TotpEnabledAtEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) +} + +// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field. +func TotpEnabledAtNEQ(v time.Time) predicate.User { + return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v)) +} + +// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field. +func TotpEnabledAtIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...)) +} + +// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field. +func TotpEnabledAtNotIn(vs ...time.Time) predicate.User { + return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...)) +} + +// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field. +func TotpEnabledAtGT(v time.Time) predicate.User { + return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v)) +} + +// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field. +func TotpEnabledAtGTE(v time.Time) predicate.User { + return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v)) +} + +// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field. +func TotpEnabledAtLT(v time.Time) predicate.User { + return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v)) +} + +// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field. +func TotpEnabledAtLTE(v time.Time) predicate.User { + return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v)) +} + +// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field. +func TotpEnabledAtIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt)) +} + +// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field. +func TotpEnabledAtNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index e12e476c..6b4ebc59 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate { return _c } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate { + _c.mutation.SetTotpSecretEncrypted(v) + return _c +} + +// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil. +func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate { + if v != nil { + _c.SetTotpSecretEncrypted(*v) + } + return _c +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate { + _c.mutation.SetTotpEnabled(v) + return _c +} + +// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil. +func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate { + if v != nil { + _c.SetTotpEnabled(*v) + } + return _c +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate { + _c.mutation.SetTotpEnabledAt(v) + return _c +} + +// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil. +func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { + if v != nil { + _c.SetTotpEnabledAt(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultNotes _c.mutation.SetNotes(v) } + if _, ok := _c.mutation.TotpEnabled(); !ok { + v := user.DefaultTotpEnabled + _c.mutation.SetTotpEnabled(v) + } return nil } @@ -422,6 +468,9 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.Notes(); !ok { return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)} } + if _, ok := _c.mutation.TotpEnabled(); !ok { + return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} + } return nil } @@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldNotes, field.TypeString, value) _node.Notes = value } + if value, ok := _c.mutation.TotpSecretEncrypted(); ok { + _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value) + _node.TotpSecretEncrypted = &value + } + if value, ok := _c.mutation.TotpEnabled(); ok { + _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value) + _node.TotpEnabled = value + } + if value, ok := _c.mutation.TotpEnabledAt(); ok { + _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) + _node.TotpEnabledAt = &value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert { return u } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert { + u.Set(user.FieldTotpSecretEncrypted, v) + return u +} + +// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create. +func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert { + u.SetExcluded(user.FieldTotpSecretEncrypted) + return u +} + +// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field. +func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert { + u.SetNull(user.FieldTotpSecretEncrypted) + return u +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert { + u.Set(user.FieldTotpEnabled, v) + return u +} + +// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create. +func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert { + u.SetExcluded(user.FieldTotpEnabled) + return u +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert { + u.Set(user.FieldTotpEnabledAt, v) + return u +} + +// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create. +func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert { + u.SetExcluded(user.FieldTotpEnabledAt) + return u +} + +// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field. +func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { + u.SetNull(user.FieldTotpEnabledAt) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne { }) } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetTotpSecretEncrypted(v) + }) +} + +// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateTotpSecretEncrypted() + }) +} + +// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field. +func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearTotpSecretEncrypted() + }) +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetTotpEnabled(v) + }) +} + +// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateTotpEnabled() + }) +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetTotpEnabledAt(v) + }) +} + +// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateTotpEnabledAt() + }) +} + +// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field. +func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearTotpEnabledAt() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk { }) } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetTotpSecretEncrypted(v) + }) +} + +// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateTotpSecretEncrypted() + }) +} + +// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field. +func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearTotpSecretEncrypted() + }) +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetTotpEnabled(v) + }) +} + +// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateTotpEnabled() + }) +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetTotpEnabledAt(v) + }) +} + +// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateTotpEnabledAt() + }) +} + +// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field. +func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearTotpEnabledAt() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index cf189fea..b98a41c6 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate { return _u } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate { + _u.mutation.SetTotpSecretEncrypted(v) + return _u +} + +// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil. +func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate { + if v != nil { + _u.SetTotpSecretEncrypted(*v) + } + return _u +} + +// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field. +func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate { + _u.mutation.ClearTotpSecretEncrypted() + return _u +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate { + _u.mutation.SetTotpEnabled(v) + return _u +} + +// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil. +func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate { + if v != nil { + _u.SetTotpEnabled(*v) + } + return _u +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate { + _u.mutation.SetTotpEnabledAt(v) + return _u +} + +// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil. +func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate { + if v != nil { + _u.SetTotpEnabledAt(*v) + } + return _u +} + +// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field. +func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { + _u.mutation.ClearTotpEnabledAt() + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Notes(); ok { _spec.SetField(user.FieldNotes, field.TypeString, value) } + if value, ok := _u.mutation.TotpSecretEncrypted(); ok { + _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value) + } + if _u.mutation.TotpSecretEncryptedCleared() { + _spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString) + } + if value, ok := _u.mutation.TotpEnabled(); ok { + _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.TotpEnabledAt(); ok { + _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) + } + if _u.mutation.TotpEnabledAtCleared() { + _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne { return _u } +// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field. +func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne { + _u.mutation.SetTotpSecretEncrypted(v) + return _u +} + +// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne { + if v != nil { + _u.SetTotpSecretEncrypted(*v) + } + return _u +} + +// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field. +func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne { + _u.mutation.ClearTotpSecretEncrypted() + return _u +} + +// SetTotpEnabled sets the "totp_enabled" field. +func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne { + _u.mutation.SetTotpEnabled(v) + return _u +} + +// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne { + if v != nil { + _u.SetTotpEnabled(*v) + } + return _u +} + +// SetTotpEnabledAt sets the "totp_enabled_at" field. +func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne { + _u.mutation.SetTotpEnabledAt(v) + return _u +} + +// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne { + if v != nil { + _u.SetTotpEnabledAt(*v) + } + return _u +} + +// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field. +func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { + _u.mutation.ClearTotpEnabledAt() + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if value, ok := _u.mutation.Notes(); ok { _spec.SetField(user.FieldNotes, field.TypeString, value) } + if value, ok := _u.mutation.TotpSecretEncrypted(); ok { + _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value) + } + if _u.mutation.TotpSecretEncryptedCleared() { + _spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString) + } + if value, ok := _u.mutation.TotpEnabled(); ok { + _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.TotpEnabledAt(); ok { + _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) + } + if _u.mutation.TotpEnabledAtCleared() { + _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index fd429b07..ad7d76b6 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -37,6 +37,7 @@ require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -106,6 +107,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/pquerna/otp v1.5.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.57.1 // indirect github.com/refraction-networking/utls v1.8.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index aa10718c..0addb5bb 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 00a78480..477cb59d 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -47,6 +47,7 @@ type Config struct { Redis RedisConfig `mapstructure:"redis"` Ops OpsConfig `mapstructure:"ops"` JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` Default DefaultConfig `mapstructure:"default"` RateLimit RateLimitConfig `mapstructure:"rate_limit"` @@ -466,6 +467,16 @@ type JWTConfig struct { ExpireHour int `mapstructure:"expire_hour"` } +// TotpConfig TOTP 双因素认证配置 +type TotpConfig struct { + // EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码) + // 如果为空,将自动生成一个随机密钥(仅适用于开发环境) + EncryptionKey string `mapstructure:"encryption_key"` + // EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成) + // 只有手动配置了密钥才允许在管理后台启用 TOTP 功能 + EncryptionKeyConfigured bool `mapstructure:"-"` +} + type TurnstileConfig struct { Required bool `mapstructure:"required"` } @@ -626,6 +637,20 @@ func Load() (*Config, error) { log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.") } + // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) + cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) + if cfg.Totp.EncryptionKey == "" { + key, err := generateJWTSecret(32) // Reuse the same random generation function + if err != nil { + return nil, fmt.Errorf("generate totp encryption key error: %w", err) + } + cfg.Totp.EncryptionKey = key + cfg.Totp.EncryptionKeyConfigured = false + log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.") + } else { + cfg.Totp.EncryptionKeyConfigured = true + } + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } @@ -756,6 +781,9 @@ func setDefaults() { viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) + // TOTP + viper.SetDefault("totp.encryption_key", "") + // Default // Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP). // Do not ship fixed defaults here to avoid insecure "known credentials" in production. diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 188aa0ec..bbf5d026 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } - // 如果 project_id 获取失败,先更新凭证,再标记账户为 error + // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值 + // 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失 + if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" { + if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" { + newCredentials["project_id"] = oldProjectID + } + } + + // 如果 project_id 获取失败,更新凭证但不标记为 error + // LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试 if tokenInfo.ProjectIDMissing { - // 先更新凭证 + // 先更新凭证(token 本身刷新成功了) _, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{ Credentials: newCredentials, }) @@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) { response.InternalError(c, "Failed to update credentials: "+updateErr.Error()) return } - // 标记账户为 error - if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil { - response.InternalError(c, "Failed to set account error: "+setErr.Error()) - return - } + // 不标记为 error,只返回警告信息 response.Success(c, gin.H{ - "message": "Token refreshed but project_id is missing, account marked as error", - "warning": "missing_project_id", + "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)", + "warning": "missing_project_id_temporary", }) return } diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0e3e0a2f..4a798fa1 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -48,6 +48,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + TotpEnabled: settings.TotpEnabled, + TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), SMTPHost: settings.SMTPHost, SMTPPort: settings.SMTPPort, SMTPUsername: settings.SMTPUsername, @@ -89,9 +92,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // UpdateSettingsRequest 更新设置请求 type UpdateSettingsRequest struct { // 注册设置 - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 // 邮件服务设置 SMTPHost string `json:"smtp_host"` @@ -198,6 +203,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // TOTP 双因素认证参数验证 + // 只有手动配置了加密密钥才允许启用 TOTP 功能 + if req.TotpEnabled && !previousSettings.TotpEnabled { + // 尝试启用 TOTP,检查加密密钥是否已手动配置 + if !h.settingService.IsTotpEncryptionKeyConfigured() { + response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.") + return + } + } + // LinuxDo Connect 参数验证 if req.LinuxDoConnectEnabled { req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID) @@ -243,6 +258,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { RegistrationEnabled: req.RegistrationEnabled, EmailVerifyEnabled: req.EmailVerifyEnabled, PromoCodeEnabled: req.PromoCodeEnabled, + PasswordResetEnabled: req.PasswordResetEnabled, + TotpEnabled: req.TotpEnabled, SMTPHost: req.SMTPHost, SMTPPort: req.SMTPPort, SMTPUsername: req.SMTPUsername, @@ -318,6 +335,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { RegistrationEnabled: updatedSettings.RegistrationEnabled, EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, PromoCodeEnabled: updatedSettings.PromoCodeEnabled, + PasswordResetEnabled: updatedSettings.PasswordResetEnabled, + TotpEnabled: updatedSettings.TotpEnabled, + TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(), SMTPHost: updatedSettings.SMTPHost, SMTPPort: updatedSettings.SMTPPort, SMTPUsername: updatedSettings.SMTPUsername, @@ -384,6 +404,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EmailVerifyEnabled != after.EmailVerifyEnabled { changed = append(changed, "email_verify_enabled") } + if before.PasswordResetEnabled != after.PasswordResetEnabled { + changed = append(changed, "password_reset_enabled") + } + if before.TotpEnabled != after.TotpEnabled { + changed = append(changed, "totp_enabled") + } if before.SMTPHost != after.SMTPHost { changed = append(changed, "smtp_host") } diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index a0d1456f..51995ab1 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) { } status := c.Query("status") - subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status) + // Parse sorting parameters + sortBy := c.DefaultQuery("sort_by", "created_at") + sortOrder := c.DefaultQuery("sort_order", "desc") + + subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 89f34aae..3522407d 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -1,6 +1,8 @@ package handler import ( + "log/slog" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" @@ -18,16 +20,18 @@ type AuthHandler struct { userService *service.UserService settingSvc *service.SettingService promoService *service.PromoService + totpService *service.TotpService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler { return &AuthHandler{ cfg: cfg, authService: authService, userService: userService, settingSvc: settingService, promoService: promoService, + totpService: totpService, } } @@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) { return } + // Check if TOTP 2FA is enabled for this user + if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { + // Create a temporary login session for 2FA + tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email) + if err != nil { + response.InternalError(c, "Failed to create 2FA session") + return + } + + response.Success(c, TotpLoginResponse{ + Requires2FA: true, + TempToken: tempToken, + UserEmailMasked: service.MaskEmail(user.Email), + }) + return + } + + response.Success(c, AuthResponse{ + AccessToken: token, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) +} + +// TotpLoginResponse represents the response when 2FA is required +type TotpLoginResponse struct { + Requires2FA bool `json:"requires_2fa"` + TempToken string `json:"temp_token,omitempty"` + UserEmailMasked string `json:"user_email_masked,omitempty"` +} + +// Login2FARequest represents the 2FA login request +type Login2FARequest struct { + TempToken string `json:"temp_token" binding:"required"` + TotpCode string `json:"totp_code" binding:"required,len=6"` +} + +// Login2FA completes the login with 2FA verification +// POST /api/v1/auth/login/2fa +func (h *AuthHandler) Login2FA(c *gin.Context) { + var req Login2FARequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + slog.Debug("login_2fa_request", + "temp_token_len", len(req.TempToken), + "totp_code_len", len(req.TotpCode)) + + // Get the login session + session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken) + if err != nil || session == nil { + tokenPrefix := "" + if len(req.TempToken) >= 8 { + tokenPrefix = req.TempToken[:8] + } + slog.Debug("login_2fa_session_invalid", + "temp_token_prefix", tokenPrefix, + "error", err) + response.BadRequest(c, "Invalid or expired 2FA session") + return + } + + slog.Debug("login_2fa_session_found", + "user_id", session.UserID, + "email", session.Email) + + // Verify the TOTP code + if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil { + slog.Debug("login_2fa_verify_failed", + "user_id", session.UserID, + "error", err) + response.ErrorFrom(c, err) + return + } + + // Delete the login session + _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken) + + // Get the user + user, err := h.userService.GetByID(c.Request.Context(), session.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // Generate the JWT token + token, err := h.authService.GenerateToken(user) + if err != nil { + response.InternalError(c, "Failed to generate token") + return + } + response.Success(c, AuthResponse{ AccessToken: token, TokenType: "Bearer", @@ -247,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) { BonusAmount: promoCode.BonusAmount, }) } + +// ForgotPasswordRequest 忘记密码请求 +type ForgotPasswordRequest struct { + Email string `json:"email" binding:"required,email"` + TurnstileToken string `json:"turnstile_token"` +} + +// ForgotPasswordResponse 忘记密码响应 +type ForgotPasswordResponse struct { + Message string `json:"message"` +} + +// ForgotPassword 请求密码重置 +// POST /api/v1/auth/forgot-password +func (h *AuthHandler) ForgotPassword(c *gin.Context) { + var req ForgotPasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + // Turnstile 验证 + if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + response.ErrorFrom(c, err) + return + } + + // Build frontend base URL from request + scheme := "https" + if c.Request.TLS == nil { + // Check X-Forwarded-Proto header (common in reverse proxy setups) + if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { + scheme = proto + } else { + scheme = "http" + } + } + frontendBaseURL := scheme + "://" + c.Request.Host + + // Request password reset (async) + // Note: This returns success even if email doesn't exist (to prevent enumeration) + if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, ForgotPasswordResponse{ + Message: "If your email is registered, you will receive a password reset link shortly.", + }) +} + +// ResetPasswordRequest 重置密码请求 +type ResetPasswordRequest struct { + Email string `json:"email" binding:"required,email"` + Token string `json:"token" binding:"required"` + NewPassword string `json:"new_password" binding:"required,min=6"` +} + +// ResetPasswordResponse 重置密码响应 +type ResetPasswordResponse struct { + Message string `json:"message"` +} + +// ResetPassword 重置密码 +// POST /api/v1/auth/reset-password +func (h *AuthHandler) ResetPassword(c *gin.Context) { + var req ResetPasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + // Reset password + if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, ResetPasswordResponse{ + Message: "Your password has been reset successfully. You can now log in with your new password.", + }) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 01f39478..fc7b1349 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -2,9 +2,12 @@ package dto // SystemSettings represents the admin settings API response payload. type SystemSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置 SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` @@ -54,21 +57,23 @@ type SystemSettings struct { } type PublicSettings struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证 + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version"` } // StreamTimeoutSettings 流超时处理配置 DTO diff --git a/backend/internal/handler/gemini_cli_session_test.go b/backend/internal/handler/gemini_cli_session_test.go new file mode 100644 index 00000000..0b37f5f2 --- /dev/null +++ b/backend/internal/handler/gemini_cli_session_test.go @@ -0,0 +1,122 @@ +//go:build unit + +package handler + +import ( + "crypto/sha256" + "encoding/hex" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractGeminiCLISessionHash(t *testing.T) { + tests := []struct { + name string + body string + privilegedUserID string + wantEmpty bool + wantHash string + }{ + { + name: "with privileged-user-id and tmp dir", + body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`, + privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3", + wantEmpty: false, + wantHash: func() string { + combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740" + hash := sha256.Sum256([]byte(combined)) + return hex.EncodeToString(hash[:]) + }(), + }, + { + name: "without privileged-user-id but with tmp dir", + body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`, + privilegedUserID: "", + wantEmpty: false, + wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + }, + { + name: "without tmp dir", + body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`, + privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3", + wantEmpty: true, + }, + { + name: "empty body", + body: "", + privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3", + wantEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 创建测试上下文 + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/test", nil) + if tt.privilegedUserID != "" { + c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID) + } + + // 调用函数 + result := extractGeminiCLISessionHash(c, []byte(tt.body)) + + // 验证结果 + if tt.wantEmpty { + require.Empty(t, result, "expected empty session hash") + } else { + require.NotEmpty(t, result, "expected non-empty session hash") + require.Equal(t, tt.wantHash, result, "session hash mismatch") + } + }) + } +} + +func TestGeminiCLITmpDirRegex(t *testing.T) { + tests := []struct { + name string + input string + wantMatch bool + wantHash string + }{ + { + name: "valid tmp dir path", + input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + wantMatch: true, + wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + }, + { + name: "valid tmp dir path in text", + input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text", + wantMatch: true, + wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740", + }, + { + name: "invalid hash length", + input: "/Users/ianshaw/.gemini/tmp/abc123", + wantMatch: false, + }, + { + name: "no tmp dir", + input: "Hello world", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input) + if tt.wantMatch { + require.NotNil(t, match, "expected regex to match") + require.Len(t, match, 2, "expected 2 capture groups") + require.Equal(t, tt.wantHash, match[1], "hash mismatch") + } else { + require.Nil(t, match, "expected regex not to match") + } + }) + } +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index c7646b38..32f83013 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -1,11 +1,15 @@ package handler import ( + "bytes" "context" + "crypto/sha256" + "encoding/hex" "errors" "io" "log" "net/http" + "regexp" "strings" "time" @@ -19,6 +23,17 @@ import ( "github.com/gin-gonic/gin" ) +// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值 +// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希] +var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`) + +func isGeminiCLIRequest(c *gin.Context, body []byte) bool { + if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" { + return true + } + return geminiCLITmpDirRegex.Match(body) +} + // GeminiV1BetaListModels proxies: // GET /v1beta/models func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { @@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 3) select account (sticky session based on request body) - parsedReq, _ := service.ParseGatewayRequest(body) - sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + // 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希) + sessionHash := extractGeminiCLISessionHash(c, body) + if sessionHash == "" { + // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) + parsedReq, _ := service.ParseGatewayRequest(body) + sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) + } sessionKey := sessionHash if sessionHash != "" { sessionKey = "gemini:" + sessionHash } + + // 查询粘性会话绑定的账号 ID(用于检测账号切换) + var sessionBoundAccountID int64 + if sessionKey != "" { + sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + } + isCLI := isGeminiCLIRequest(c, body) + cleanedForUnknownBinding := false + maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) @@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { account := selection.Account setOpsSelectedAccount(c, account.ID) + // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature + // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 + if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { + log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) + body = service.CleanGeminiNativeThoughtSignatures(body) + sessionBoundAccountID = account.ID + } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { + // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。 + // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 + log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively") + body = service.CleanGeminiNativeThoughtSignatures(body) + cleanedForUnknownBinding = true + sessionBoundAccountID = account.ID + } else if sessionBoundAccountID == 0 { + // 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。 + sessionBoundAccountID = account.ID + } + // 4) account concurrency slot accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { @@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { } return false } + +// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 +// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 +// +// 会话标识生成策略: +// 1. 从请求体中提取 tmp 目录哈希(64位十六进制) +// 2. 从 header 中提取 privileged-user-id(UUID) +// 3. 组合两者生成 SHA256 哈希作为最终的会话标识 +// +// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。 +// +// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests. +// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body. +func extractGeminiCLISessionHash(c *gin.Context, body []byte) string { + // 1. 从请求体中提取 tmp 目录哈希 + match := geminiCLITmpDirRegex.FindSubmatch(body) + if len(match) < 2 { + return "" // 没有找到 tmp 目录,不使用粘性会话 + } + tmpDirHash := string(match[1]) + + // 2. 提取 privileged-user-id + privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) + + // 3. 组合生成最终的 session hash + if privilegedUserID != "" { + // 组合两个标识符:privileged-user-id + tmp 目录哈希 + combined := privilegedUserID + ":" + tmpDirHash + hash := sha256.Sum256([]byte(combined)) + return hex.EncodeToString(hash[:]) + } + + // 如果没有 privileged-user-id,直接使用 tmp 目录哈希 + return tmpDirHash +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 5b1b317d..907c314d 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -37,6 +37,7 @@ type Handlers struct { Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler Setting *SettingHandler + Totp *TotpHandler } // BuildInfo contains build-time information diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 8723c746..9c0bde33 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -32,20 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { } response.Success(c, dto.PublicSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - Version: h.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + Version: h.version, }) } diff --git a/backend/internal/handler/totp_handler.go b/backend/internal/handler/totp_handler.go new file mode 100644 index 00000000..5c5eb567 --- /dev/null +++ b/backend/internal/handler/totp_handler.go @@ -0,0 +1,181 @@ +package handler + +import ( + "github.com/gin-gonic/gin" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// TotpHandler handles TOTP-related requests +type TotpHandler struct { + totpService *service.TotpService +} + +// NewTotpHandler creates a new TotpHandler +func NewTotpHandler(totpService *service.TotpService) *TotpHandler { + return &TotpHandler{ + totpService: totpService, + } +} + +// TotpStatusResponse represents the TOTP status response +type TotpStatusResponse struct { + Enabled bool `json:"enabled"` + EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp + FeatureEnabled bool `json:"feature_enabled"` +} + +// GetStatus returns the TOTP status for the current user +// GET /api/v1/user/totp/status +func (h *TotpHandler) GetStatus(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + resp := TotpStatusResponse{ + Enabled: status.Enabled, + FeatureEnabled: status.FeatureEnabled, + } + + if status.EnabledAt != nil { + ts := status.EnabledAt.Unix() + resp.EnabledAt = &ts + } + + response.Success(c, resp) +} + +// TotpSetupRequest represents the request to initiate TOTP setup +type TotpSetupRequest struct { + EmailCode string `json:"email_code"` + Password string `json:"password"` +} + +// TotpSetupResponse represents the TOTP setup response +type TotpSetupResponse struct { + Secret string `json:"secret"` + QRCodeURL string `json:"qr_code_url"` + SetupToken string `json:"setup_token"` + Countdown int `json:"countdown"` +} + +// InitiateSetup starts the TOTP setup process +// POST /api/v1/user/totp/setup +func (h *TotpHandler) InitiateSetup(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req TotpSetupRequest + if err := c.ShouldBindJSON(&req); err != nil { + // Allow empty body (optional params) + req = TotpSetupRequest{} + } + + result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, TotpSetupResponse{ + Secret: result.Secret, + QRCodeURL: result.QRCodeURL, + SetupToken: result.SetupToken, + Countdown: result.Countdown, + }) +} + +// TotpEnableRequest represents the request to enable TOTP +type TotpEnableRequest struct { + TotpCode string `json:"totp_code" binding:"required,len=6"` + SetupToken string `json:"setup_token" binding:"required"` +} + +// Enable completes the TOTP setup +// POST /api/v1/user/totp/enable +func (h *TotpHandler) Enable(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req TotpEnableRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"success": true}) +} + +// TotpDisableRequest represents the request to disable TOTP +type TotpDisableRequest struct { + EmailCode string `json:"email_code"` + Password string `json:"password"` +} + +// Disable disables TOTP for the current user +// POST /api/v1/user/totp/disable +func (h *TotpHandler) Disable(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req TotpDisableRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"success": true}) +} + +// GetVerificationMethod returns the verification method for TOTP operations +// GET /api/v1/user/totp/verification-method +func (h *TotpHandler) GetVerificationMethod(c *gin.Context) { + method := h.totpService.GetVerificationMethod(c.Request.Context()) + response.Success(c, method) +} + +// SendVerifyCode sends an email verification code for TOTP operations +// POST /api/v1/user/totp/send-code +func (h *TotpHandler) SendVerifyCode(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"success": true}) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 2af7905e..92e8edeb 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -70,6 +70,7 @@ func ProvideHandlers( gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, settingHandler *SettingHandler, + totpHandler *TotpHandler, ) *Handlers { return &Handlers{ Auth: authHandler, @@ -82,6 +83,7 @@ func ProvideHandlers( Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, Setting: settingHandler, + Totp: totpHandler, } } @@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewTotpHandler, ProvideSettingHandler, // Admin handlers diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 1b21bd58..63f6ee7c 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -367,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu Text: block.Thinking, Thought: true, } - // 保留原有 signature(Claude 模型需要有效的 signature) - if block.Signature != "" { + // signature 处理: + // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) + // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature + if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { part.ThoughtSignature = block.Signature } else if !allowDummyThought { // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 @@ -407,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu }, } // tool_use 的 signature 处理: - // - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验) - // - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路) - if allowDummyThought { - part.ThoughtSignature = dummyThoughtSignature - } else if block.Signature != "" && block.Signature != dummyThoughtSignature { + // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) + // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature + if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { part.ThoughtSignature = block.Signature + } else if allowDummyThought { + part.ThoughtSignature = dummyThoughtSignature } parts = append(parts, part) diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index 60ee6f63..9d62a4a1 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"} ]` - t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) { + t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) { toolIDToName := make(map[string]string) parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true) if err != nil { @@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { if len(parts) != 1 || parts[0].FunctionCall == nil { t.Fatalf("expected 1 functionCall part, got %+v", parts) } + if parts[0].ThoughtSignature != "sig_tool_abc" { + t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature) + } + }) + + t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) { + contentNoSig := `[ + {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}} + ]` + toolIDToName := make(map[string]string) + parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true) + if err != nil { + t.Fatalf("buildParts() error = %v", err) + } + if len(parts) != 1 || parts[0].FunctionCall == nil { + t.Fatalf("expected 1 functionCall part, got %+v", parts) + } if parts[0].ThoughtSignature != dummyThoughtSignature { t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) } diff --git a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go new file mode 100644 index 00000000..eea74fcc --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go @@ -0,0 +1,278 @@ +//go:build integration + +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +// +// Integration tests for verifying TLS fingerprint correctness. +// These tests make actual network requests to external services and should be run manually. +// +// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/... +package tlsfingerprint + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// skipIfExternalServiceUnavailable checks if the external service is available. +// If not, it skips the test instead of failing. +func skipIfExternalServiceUnavailable(t *testing.T, err error) { + t.Helper() + if err != nil { + // Check for common network/TLS errors that indicate external service issues + errStr := err.Error() + if strings.Contains(errStr, "certificate has expired") || + strings.Contains(errStr, "certificate is not yet valid") || + strings.Contains(errStr, "connection refused") || + strings.Contains(errStr, "no such host") || + strings.Contains(errStr, "network is unreachable") || + strings.Contains(errStr, "timeout") { + t.Skipf("skipping test: external service unavailable: %v", err) + } + t.Fatalf("failed to get fingerprint: %v", err) + } +} + +// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. +// This test uses tls.peet.ws to verify the fingerprint. +// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) +// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) +func TestJA3Fingerprint(t *testing.T) { + // Skip if network is unavailable or if running in short mode + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + profile := &Profile{ + Name: "Claude CLI Test", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Use tls.peet.ws fingerprint detection API + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + skipIfExternalServiceUnavailable(t, err) + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + } + + // Log all fingerprint information + t.Logf("JA3: %s", fpResp.TLS.JA3) + t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) + t.Logf("JA4: %s", fpResp.TLS.JA4) + t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) + t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) + + // Verify JA3 hash matches expected value + expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + if fpResp.TLS.JA3Hash == expectedJA3Hash { + t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + } else { + t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) + } + + // Verify JA4 fingerprint + // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] + // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) + // The suffix _a33745022dd6_1f22a2ca17c4 should match + expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" + if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { + t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + } else { + t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + } + + // Verify JA4 prefix (t13d5911h1 or t13i5911h1) + // d = domain (SNI present), i = IP (no SNI) + // Since we connect to tls.peet.ws (domain), we expect 'd' + expectedJA4Prefix := "t13d5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + } else { + // Also accept 'i' variant for IP connections + altPrefix := "t13i5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { + t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) + } else { + t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) + } + } + + // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) + if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") + } else { + t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") + } + + // Verify extension list (should be 11 extensions including SNI) + // Expected: 0-11-10-35-16-22-23-13-43-45-51 + expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { + t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) + } else { + t.Logf("Warning: JA3 extension list may differ") + } +} + +// TestProfileExpectation defines expected fingerprint values for a profile. +type TestProfileExpectation struct { + Profile *Profile + ExpectedJA3 string // Expected JA3 hash (empty = don't check) + ExpectedJA4 string // Expected full JA4 (empty = don't check) + JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) +} + +// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. +// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/... +func TestAllProfiles(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // Define all profiles to test with their expected fingerprints + // These profiles are from config.yaml gateway.tls_fingerprint.profiles + profiles := []TestProfileExpectation{ + { + // Linux x64 Node.js v22.17.1 + // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c + // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 + Profile: &Profile{ + Name: "linux_x64_node_v22171", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part + }, + { + // MacOS arm64 Node.js v22.18.0 + // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea + // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 + Profile: &Profile{ + Name: "macos_arm64_node_v22180", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) + }, + } + + for _, tc := range profiles { + tc := tc // capture range variable + t.Run(tc.Profile.Name, func(t *testing.T) { + fp := fetchFingerprint(t, tc.Profile) + if fp == nil { + return // fetchFingerprint already called t.Fatal + } + + t.Logf("Profile: %s", tc.Profile.Name) + t.Logf(" JA3: %s", fp.JA3) + t.Logf(" JA3 Hash: %s", fp.JA3Hash) + t.Logf(" JA4: %s", fp.JA4) + t.Logf(" PeetPrint: %s", fp.PeetPrint) + t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash) + + // Verify expectations + if tc.ExpectedJA3 != "" { + if fp.JA3Hash == tc.ExpectedJA3 { + t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3) + } else { + t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3) + } + } + + if tc.ExpectedJA4 != "" { + if fp.JA4 == tc.ExpectedJA4 { + t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4) + } else { + t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4) + } + } + + // Check JA4 cipher hash (stable middle part) + // JA4 format: prefix_cipherHash_extHash + if tc.JA4CipherHash != "" { + if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") { + t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash) + } else { + t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash) + } + } + }) + } +} + +// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info. +func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo { + t.Helper() + + dialer := NewDialer(profile, nil) + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + return nil + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + skipIfExternalServiceUnavailable(t, err) + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + return nil + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + return nil + } + + return &fpResp.TLS +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index 845d51e5..dff7570f 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -1,21 +1,16 @@ // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // -// Integration tests for verifying TLS fingerprint correctness. -// These tests make actual network requests and should be run manually. +// Unit tests for TLS fingerprint dialer. +// Integration tests that require external network are in dialer_integration_test.go +// and require the 'integration' build tag. // -// Run with: go test -v ./internal/pkg/tlsfingerprint/... -// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/... +// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/... +// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/... package tlsfingerprint import ( - "context" - "encoding/json" - "io" - "net/http" "net/url" - "strings" "testing" - "time" ) // FingerprintResponse represents the response from tls.peet.ws/api/all. @@ -36,148 +31,6 @@ type TLSInfo struct { SessionID string `json:"session_id"` } -// TestDialerBasicConnection tests that the dialer can establish TLS connections. -func TestDialerBasicConnection(t *testing.T) { - if testing.Short() { - t.Skip("skipping network test in short mode") - } - - // Create a dialer with default profile - profile := &Profile{ - Name: "Test Profile", - EnableGREASE: false, - } - dialer := NewDialer(profile, nil) - - // Create HTTP client with custom TLS dialer - client := &http.Client{ - Transport: &http.Transport{ - DialTLSContext: dialer.DialTLSContext, - }, - Timeout: 30 * time.Second, - } - - // Make a request to a known HTTPS endpoint - resp, err := client.Get("https://www.google.com") - if err != nil { - t.Fatalf("failed to connect: %v", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - t.Errorf("expected status 200, got %d", resp.StatusCode) - } -} - -// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. -// This test uses tls.peet.ws to verify the fingerprint. -// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) -// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP) -func TestJA3Fingerprint(t *testing.T) { - // Skip if network is unavailable or if running in short mode - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - profile := &Profile{ - Name: "Claude CLI Test", - EnableGREASE: false, - } - dialer := NewDialer(profile, nil) - - client := &http.Client{ - Transport: &http.Transport{ - DialTLSContext: dialer.DialTLSContext, - }, - Timeout: 30 * time.Second, - } - - // Use tls.peet.ws fingerprint detection API - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) - if err != nil { - t.Fatalf("failed to create request: %v", err) - } - req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("failed to get fingerprint: %v", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("failed to read response: %v", err) - } - - var fpResp FingerprintResponse - if err := json.Unmarshal(body, &fpResp); err != nil { - t.Logf("Response body: %s", string(body)) - t.Fatalf("failed to parse fingerprint response: %v", err) - } - - // Log all fingerprint information - t.Logf("JA3: %s", fpResp.TLS.JA3) - t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) - t.Logf("JA4: %s", fpResp.TLS.JA4) - t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) - t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) - - // Verify JA3 hash matches expected value - expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" - if fpResp.TLS.JA3Hash == expectedJA3Hash { - t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) - } else { - t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) - } - - // Verify JA4 fingerprint - // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] - // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) - // The suffix _a33745022dd6_1f22a2ca17c4 should match - expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" - if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { - t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) - } else { - t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) - } - - // Verify JA4 prefix (t13d5911h1 or t13i5911h1) - // d = domain (SNI present), i = IP (no SNI) - // Since we connect to tls.peet.ws (domain), we expect 'd' - expectedJA4Prefix := "t13d5911h1" - if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { - t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) - } else { - // Also accept 'i' variant for IP connections - altPrefix := "t13i5911h1" - if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { - t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) - } else { - t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) - } - } - - // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) - if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { - t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") - } else { - t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") - } - - // Verify extension list (should be 11 extensions including SNI) - // Expected: 0-11-10-35-16-22-23-13-43-45-51 - expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" - if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { - t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) - } else { - t.Logf("Warning: JA3 extension list may differ") - } -} - // TestDialerWithProfile tests that different profiles produce different fingerprints. func TestDialerWithProfile(t *testing.T) { // Create two dialers with different profiles @@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL { } return u } - -// TestProfileExpectation defines expected fingerprint values for a profile. -type TestProfileExpectation struct { - Profile *Profile - ExpectedJA3 string // Expected JA3 hash (empty = don't check) - ExpectedJA4 string // Expected full JA4 (empty = don't check) - JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) -} - -// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. -// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... -func TestAllProfiles(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - // Define all profiles to test with their expected fingerprints - // These profiles are from config.yaml gateway.tls_fingerprint.profiles - profiles := []TestProfileExpectation{ - { - // Linux x64 Node.js v22.17.1 - // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c - // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 - Profile: &Profile{ - Name: "linux_x64_node_v22171", - EnableGREASE: false, - CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, - Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, - PointFormats: []uint8{0, 1, 2}, - }, - JA4CipherHash: "a33745022dd6", // stable part - }, - { - // MacOS arm64 Node.js v22.18.0 - // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea - // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 - Profile: &Profile{ - Name: "macos_arm64_node_v22180", - EnableGREASE: false, - CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, - Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, - PointFormats: []uint8{0, 1, 2}, - }, - JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) - }, - } - - for _, tc := range profiles { - tc := tc // capture range variable - t.Run(tc.Profile.Name, func(t *testing.T) { - fp := fetchFingerprint(t, tc.Profile) - if fp == nil { - return // fetchFingerprint already called t.Fatal - } - - t.Logf("Profile: %s", tc.Profile.Name) - t.Logf(" JA3: %s", fp.JA3) - t.Logf(" JA3 Hash: %s", fp.JA3Hash) - t.Logf(" JA4: %s", fp.JA4) - t.Logf(" PeetPrint: %s", fp.PeetPrint) - t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash) - - // Verify expectations - if tc.ExpectedJA3 != "" { - if fp.JA3Hash == tc.ExpectedJA3 { - t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3) - } else { - t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3) - } - } - - if tc.ExpectedJA4 != "" { - if fp.JA4 == tc.ExpectedJA4 { - t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4) - } else { - t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4) - } - } - - // Check JA4 cipher hash (stable middle part) - // JA4 format: prefix_cipherHash_extHash - if tc.JA4CipherHash != "" { - if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") { - t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash) - } else { - t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash) - } - } - }) - } -} - -// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info. -func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo { - t.Helper() - - dialer := NewDialer(profile, nil) - client := &http.Client{ - Transport: &http.Transport{ - DialTLSContext: dialer.DialTLSContext, - }, - Timeout: 30 * time.Second, - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) - if err != nil { - t.Fatalf("failed to create request: %v", err) - return nil - } - req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("failed to get fingerprint: %v", err) - return nil - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("failed to read response: %v", err) - return nil - } - - var fpResp FingerprintResponse - if err := json.Unmarshal(body, &fpResp); err != nil { - t.Logf("Response body: %s", string(body)) - t.Fatalf("failed to parse fingerprint response: %v", err) - return nil - } - - return &fpResp.TLS -} diff --git a/backend/internal/repository/aes_encryptor.go b/backend/internal/repository/aes_encryptor.go new file mode 100644 index 00000000..924e3698 --- /dev/null +++ b/backend/internal/repository/aes_encryptor.go @@ -0,0 +1,95 @@ +package repository + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// AESEncryptor implements SecretEncryptor using AES-256-GCM +type AESEncryptor struct { + key []byte +} + +// NewAESEncryptor creates a new AES encryptor +func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) { + key, err := hex.DecodeString(cfg.Totp.EncryptionKey) + if err != nil { + return nil, fmt.Errorf("invalid totp encryption key: %w", err) + } + + if len(key) != 32 { + return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key)) + } + + return &AESEncryptor{key: key}, nil +} + +// Encrypt encrypts plaintext using AES-256-GCM +// Output format: base64(nonce + ciphertext + tag) +func (e *AESEncryptor) Encrypt(plaintext string) (string, error) { + block, err := aes.NewCipher(e.key) + if err != nil { + return "", fmt.Errorf("create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("create gcm: %w", err) + } + + // Generate a random nonce + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("generate nonce: %w", err) + } + + // Encrypt the plaintext + // Seal appends the ciphertext and tag to the nonce + ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) + + // Encode as base64 + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt decrypts ciphertext using AES-256-GCM +func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) { + // Decode from base64 + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("decode base64: %w", err) + } + + block, err := aes.NewCipher(e.key) + if err != nil { + return "", fmt.Errorf("create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("create gcm: %w", err) + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return "", fmt.Errorf("ciphertext too short") + } + + // Extract nonce and ciphertext + nonce, ciphertextData := data[:nonceSize], data[nonceSize:] + + // Decrypt + plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil) + if err != nil { + return "", fmt.Errorf("decrypt: %w", err) + } + + return string(plaintext), nil +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index ab890844..1e5a62df 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index e00e35dd..8f2b8eca 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -9,13 +9,27 @@ import ( "github.com/redis/go-redis/v9" ) -const verifyCodeKeyPrefix = "verify_code:" +const ( + verifyCodeKeyPrefix = "verify_code:" + passwordResetKeyPrefix = "password_reset:" + passwordResetSentAtKeyPrefix = "password_reset_sent:" +) // verifyCodeKey generates the Redis key for email verification code. func verifyCodeKey(email string) string { return verifyCodeKeyPrefix + email } +// passwordResetKey generates the Redis key for password reset token. +func passwordResetKey(email string) string { + return passwordResetKeyPrefix + email +} + +// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp. +func passwordResetSentAtKey(email string) string { + return passwordResetSentAtKeyPrefix + email +} + type emailCache struct { rdb *redis.Client } @@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e key := verifyCodeKey(email) return c.rdb.Del(ctx, key).Err() } + +// Password reset token methods + +func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) { + key := passwordResetKey(email) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + return nil, err + } + var data service.PasswordResetTokenData + if err := json.Unmarshal([]byte(val), &data); err != nil { + return nil, err + } + return &data, nil +} + +func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error { + key := passwordResetKey(email) + val, err := json.Marshal(data) + if err != nil { + return err + } + return c.rdb.Set(ctx, key, val, ttl).Err() +} + +func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error { + key := passwordResetKey(email) + return c.rdb.Del(ctx, key).Err() +} + +// Password reset email cooldown methods + +func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool { + key := passwordResetSentAtKey(email) + exists, err := c.rdb.Exists(ctx, key).Result() + return err == nil && exists > 0 +} + +func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error { + key := passwordResetSentAtKey(email) + return c.rdb.Set(ctx, key, "1", ttl).Err() +} diff --git a/backend/internal/repository/totp_cache.go b/backend/internal/repository/totp_cache.go new file mode 100644 index 00000000..2f4a8ab2 --- /dev/null +++ b/backend/internal/repository/totp_cache.go @@ -0,0 +1,149 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/redis/go-redis/v9" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +const ( + totpSetupKeyPrefix = "totp:setup:" + totpLoginKeyPrefix = "totp:login:" + totpAttemptsKeyPrefix = "totp:attempts:" + totpAttemptsTTL = 15 * time.Minute +) + +// TotpCache implements service.TotpCache using Redis +type TotpCache struct { + rdb *redis.Client +} + +// NewTotpCache creates a new TOTP cache +func NewTotpCache(rdb *redis.Client) service.TotpCache { + return &TotpCache{rdb: rdb} +} + +// GetSetupSession retrieves a TOTP setup session +func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) { + key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID) + data, err := c.rdb.Get(ctx, key).Bytes() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, fmt.Errorf("get setup session: %w", err) + } + + var session service.TotpSetupSession + if err := json.Unmarshal(data, &session); err != nil { + return nil, fmt.Errorf("unmarshal setup session: %w", err) + } + + return &session, nil +} + +// SetSetupSession stores a TOTP setup session +func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error { + key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID) + data, err := json.Marshal(session) + if err != nil { + return fmt.Errorf("marshal setup session: %w", err) + } + + if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil { + return fmt.Errorf("set setup session: %w", err) + } + + return nil +} + +// DeleteSetupSession deletes a TOTP setup session +func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID) + return c.rdb.Del(ctx, key).Err() +} + +// GetLoginSession retrieves a TOTP login session +func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) { + key := totpLoginKeyPrefix + tempToken + data, err := c.rdb.Get(ctx, key).Bytes() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, fmt.Errorf("get login session: %w", err) + } + + var session service.TotpLoginSession + if err := json.Unmarshal(data, &session); err != nil { + return nil, fmt.Errorf("unmarshal login session: %w", err) + } + + return &session, nil +} + +// SetLoginSession stores a TOTP login session +func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error { + key := totpLoginKeyPrefix + tempToken + data, err := json.Marshal(session) + if err != nil { + return fmt.Errorf("marshal login session: %w", err) + } + + if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil { + return fmt.Errorf("set login session: %w", err) + } + + return nil +} + +// DeleteLoginSession deletes a TOTP login session +func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error { + key := totpLoginKeyPrefix + tempToken + return c.rdb.Del(ctx, key).Err() +} + +// IncrementVerifyAttempts increments the verify attempt counter +func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) { + key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID) + + // Use pipeline for atomic increment and set TTL + pipe := c.rdb.Pipeline() + incrCmd := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, totpAttemptsTTL) + + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("increment verify attempts: %w", err) + } + + count, err := incrCmd.Result() + if err != nil { + return 0, fmt.Errorf("get increment result: %w", err) + } + + return int(count), nil +} + +// GetVerifyAttempts gets the current verify attempt count +func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) { + key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID) + count, err := c.rdb.Get(ctx, key).Int() + if err != nil { + if err == redis.Nil { + return 0, nil + } + return 0, fmt.Errorf("get verify attempts: %w", err) + } + return count, nil +} + +// ClearVerifyAttempts clears the verify attempt counter +func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error { + key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 006a5464..fe5b645c 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -7,6 +7,7 @@ import ( "fmt" "sort" "strings" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" dbuser "github.com/Wei-Shaw/sub2api/ent/user" @@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { dst.CreatedAt = src.CreatedAt dst.UpdatedAt = src.UpdatedAt } + +// UpdateTotpSecret 更新用户的 TOTP 加密密钥 +func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + client := clientFromContext(ctx, r.client) + update := client.User.UpdateOneID(userID) + if encryptedSecret == nil { + update = update.ClearTotpSecretEncrypted() + } else { + update = update.SetTotpSecretEncrypted(*encryptedSecret) + } + _, err := update.Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + return nil +} + +// EnableTotp 启用用户的 TOTP 双因素认证 +func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error { + client := clientFromContext(ctx, r.client) + _, err := client.User.UpdateOneID(userID). + SetTotpEnabled(true). + SetTotpEnabledAt(time.Now()). + Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + return nil +} + +// DisableTotp 禁用用户的 TOTP 双因素认证 +func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error { + client := clientFromContext(ctx, r.client) + _, err := client.User.UpdateOneID(userID). + SetTotpEnabled(false). + ClearTotpEnabledAt(). + ClearTotpSecretEncrypted(). + Save(ctx) + if err != nil { + return translatePersistenceError(err, service.ErrUserNotFound, nil) + } + return nil +} diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index cd3b9db6..5a649846 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil } -func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { client := clientFromContext(ctx, r.client) q := client.UserSubscription.Query() if userID != nil { @@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination if groupID != nil { q = q.Where(usersubscription.GroupIDEQ(*groupID)) } - if status != "" { + + // Status filtering with real-time expiration check + now := time.Now() + switch status { + case service.SubscriptionStatusActive: + // Active: status is active AND not yet expired + q = q.Where( + usersubscription.StatusEQ(service.SubscriptionStatusActive), + usersubscription.ExpiresAtGT(now), + ) + case service.SubscriptionStatusExpired: + // Expired: status is expired OR (status is active but already expired) + q = q.Where( + usersubscription.Or( + usersubscription.StatusEQ(service.SubscriptionStatusExpired), + usersubscription.And( + usersubscription.StatusEQ(service.SubscriptionStatusActive), + usersubscription.ExpiresAtLTE(now), + ), + ), + ) + case "": + // No filter + default: + // Other status (e.g., revoked) q = q.Where(usersubscription.StatusEQ(status)) } @@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination return nil, nil, err } + // Apply sorting + q = q.WithUser().WithGroup().WithAssignedByUser() + + // Determine sort field + var field string + switch sortBy { + case "expires_at": + field = usersubscription.FieldExpiresAt + case "status": + field = usersubscription.FieldStatus + default: + field = usersubscription.FieldCreatedAt + } + + // Determine sort order (default: desc) + if sortOrder == "asc" && sortBy != "" { + q = q.Order(dbent.Asc(field)) + } else { + q = q.Order(dbent.Desc(field)) + } + subs, err := q. - WithUser(). - WithGroup(). - WithAssignedByUser(). - Order(dbent.Desc(usersubscription.FieldCreatedAt)). Offset(params.Offset()). Limit(params.Limit()). All(ctx) diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 2099e5d8..60a5a378 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { group := s.mustCreateGroup("g-list") s.mustCreateSubscription(user.ID, group.ID, nil) - subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "") + subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "") s.Require().NoError(err, "List") s.Require().Len(subs, 1) s.Require().Equal(int64(1), page.Total) @@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(user1.ID, subs[0].UserID) @@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(g1.ID, subs[0].GroupID) @@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) }) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired) + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 7a8d85f4..3e1c05fc 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet( NewSchedulerCache, NewSchedulerOutboxRepository, NewProxyLatencyCache, + NewTotpCache, + + // Encryptors + NewAESEncryptor, // HTTP service ports (DI Strategy A: return interface directly) NewTurnstileVerifier, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 244dc0b8..014e95e2 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -197,7 +197,7 @@ func TestAPIContracts(t *testing.T) { UserID: 1, GroupID: 10, StartsAt: deps.now, - ExpiresAt: deps.now.Add(24 * time.Hour), + ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期 Status: service.SubscriptionStatusActive, DailyUsageUSD: 1.23, WeeklyUsageUSD: 2.34, @@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) { "user_id": 1, "group_id": 10, "starts_at": "2025-01-02T03:04:05Z", - "expires_at": "2025-01-03T03:04:05Z", + "expires_at": "2099-01-02T03:04:05Z", "status": "active", "daily_window_start": null, "weekly_window_start": null, @@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) { "registration_enabled": true, "email_verify_enabled": false, "promo_code_enabled": true, + "password_reset_enabled": false, + "totp_enabled": false, + "totp_encryption_key_configured": false, "smtp_host": "smtp.example.com", "smtp_port": 587, "smtp_username": "user", @@ -595,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingService := service.NewSettingService(settingRepo, cfg) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) - authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) @@ -754,6 +757,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID return 0, errors.New("not implemented") } +func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error { + return errors.New("not implemented") +} + +func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error { + return errors.New("not implemented") +} + type stubApiKeyCache struct{} func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { @@ -1176,7 +1191,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 84398093..920ff93f 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in return nil, nil, errors.New("not implemented") } -func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index aa691eba..33a88e82 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -26,11 +26,20 @@ func RegisterAuthRoutes( { auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) + auth.POST("/login/2fa", h.Auth.Login2FA) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, }), h.Auth.ValidatePromoCode) + // 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close) + auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.ForgotPassword) + // 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) + auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.ResetPassword) auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index ad2166fe..83cf31c4 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -22,6 +22,17 @@ func RegisterUserRoutes( user.GET("/profile", h.User.GetProfile) user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + + // TOTP 双因素认证 + totp := user.Group("/totp") + { + totp.GET("/status", h.Totp.GetStatus) + totp.GET("/verification-method", h.Totp.GetVerificationMethod) + totp.POST("/send-code", h.Totp.SendVerifyCode) + totp.POST("/setup", h.Totp.InitiateSetup) + totp.POST("/enable", h.Totp.Enable) + totp.POST("/disable", h.Totp.Disable) + } } // API Key管理 diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index afa433af..6472ccbb 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID panic("unexpected RemoveGroupFromAllowedGroups call") } +func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error { + panic("unexpected EnableTotp call") +} + +func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error { + panic("unexpected DisableTotp call") +} + type groupRepoStub struct { affectedUserIDs []int64 deleteErr error diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index 52293cd5..fa8379ed 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig result.Email = userInfo.Email } - // 获取 project_id(部分账户类型可能没有) - loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken) - if err != nil { - fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err) - } else if loadResp != nil && loadResp.CloudAICompanionProject != "" { - result.ProjectID = loadResp.CloudAICompanionProject + // 获取 project_id(部分账户类型可能没有),失败时重试 + projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3) + if loadErr != nil { + fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr) + result.ProjectIDMissing = true + } else { + result.ProjectID = projectID } return result, nil @@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou tokenInfo.Email = existingEmail } - // 每次刷新都调用 LoadCodeAssist 获取 project_id - client := antigravity.NewClient(proxyURL) - loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken) - if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" { - // LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失 - existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) + // 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试 + existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) + projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3) + + if loadErr != nil { + // LoadCodeAssist 失败,保留原有 project_id tokenInfo.ProjectID = existingProjectID - tokenInfo.ProjectIDMissing = true + // 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失 + // 如果之前有 project_id,本次只是临时故障,不应标记为错误 + if existingProjectID == "" { + tokenInfo.ProjectIDMissing = true + } } else { - tokenInfo.ProjectID = loadResp.CloudAICompanionProject + tokenInfo.ProjectID = projectID } return tokenInfo, nil } +// loadProjectIDWithRetry 带重试机制获取 project_id +// 返回 project_id 和错误,失败时会重试指定次数 +func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) { + var lastErr error + + for attempt := 0; attempt <= maxRetries; attempt++ { + if attempt > 0 { + // 指数退避:1s, 2s, 4s + backoff := time.Duration(1< 8*time.Second { + backoff = 8 * time.Second + } + time.Sleep(backoff) + } + + client := antigravity.NewClient(proxyURL) + loadResp, _, err := client.LoadCodeAssist(ctx, accessToken) + + if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { + return loadResp.CloudAICompanionProject, nil + } + + // 记录错误 + if err != nil { + lastErr = err + } else if loadResp == nil { + lastErr = fmt.Errorf("LoadCodeAssist 返回空响应") + } else { + lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id") + } + } + + return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr) +} + // BuildAccountCredentials 构建账户凭证 func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any { creds := map[string]any{ diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index a07c86e6..e33f88d0 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -3,6 +3,8 @@ package service import ( "context" "fmt" + "log" + "strings" "time" ) @@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun } newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo) + // 合并旧的 credentials,保留新 credentials 中不存在的字段 for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { newCredentials[k] = v } } - // 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记 + // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值 + // 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失 + if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" { + if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" { + newCredentials["project_id"] = oldProjectID + } + } + + // 如果 project_id 获取失败,只记录警告,不返回错误 + // LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误 + // Token 刷新本身是成功的(access_token 和 refresh_token 已更新) if tokenInfo.ProjectIDMissing { - return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity") + if tokenInfo.ProjectID != "" { + // 有旧的 project_id,本次获取失败,保留旧值 + log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID) + } else { + // 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试 + log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID) + } } return newCredentials, nil diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index ab3ed116..a807b240 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( // 生成新token return s.GenerateToken(user) } + +// IsPasswordResetEnabled 检查是否启用密码重置功能 +// 要求:必须同时开启邮件验证且 SMTP 配置正确 +func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool { + if s.settingService == nil { + return false + } + // Must have email verification enabled and SMTP configured + if !s.settingService.IsEmailVerifyEnabled(ctx) { + return false + } + return s.settingService.IsPasswordResetEnabled(ctx) +} + +// preparePasswordReset validates the password reset request and returns necessary data +// Returns (siteName, resetURL, shouldProceed) +// shouldProceed is false when we should silently return success (to prevent enumeration) +func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) { + // Check if user exists (but don't reveal this to the caller) + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // Security: Log but don't reveal that user doesn't exist + log.Printf("[Auth] Password reset requested for non-existent email: %s", email) + return "", "", false + } + log.Printf("[Auth] Database error checking email for password reset: %v", err) + return "", "", false + } + + // Check if user is active + if !user.IsActive() { + log.Printf("[Auth] Password reset requested for inactive user: %s", email) + return "", "", false + } + + // Get site name + siteName := "Sub2API" + if s.settingService != nil { + siteName = s.settingService.GetSiteName(ctx) + } + + // Build reset URL base + resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/")) + + return siteName, resetURL, true +} + +// RequestPasswordReset 请求密码重置(同步发送) +// Security: Returns the same response regardless of whether the email exists (prevent user enumeration) +func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error { + if !s.IsPasswordResetEnabled(ctx) { + return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") + } + if s.emailService == nil { + return ErrServiceUnavailable + } + + siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL) + if !shouldProceed { + return nil // Silent success to prevent enumeration + } + + if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { + log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err) + return nil // Silent success to prevent enumeration + } + + log.Printf("[Auth] Password reset email sent to: %s", email) + return nil +} + +// RequestPasswordResetAsync 异步请求密码重置(队列发送) +// Security: Returns the same response regardless of whether the email exists (prevent user enumeration) +func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error { + if !s.IsPasswordResetEnabled(ctx) { + return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") + } + if s.emailQueueService == nil { + return ErrServiceUnavailable + } + + siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL) + if !shouldProceed { + return nil // Silent success to prevent enumeration + } + + if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { + log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err) + return nil // Silent success to prevent enumeration + } + + log.Printf("[Auth] Password reset email enqueued for: %s", email) + return nil +} + +// ResetPassword 重置密码 +// Security: Increments TokenVersion to invalidate all existing JWT tokens +func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error { + // Check if password reset is enabled + if !s.IsPasswordResetEnabled(ctx) { + return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled") + } + + if s.emailService == nil { + return ErrServiceUnavailable + } + + // Verify and consume the reset token (one-time use) + if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil { + return err + } + + // Get user + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + return ErrInvalidResetToken // Token was valid but user was deleted + } + log.Printf("[Auth] Database error getting user for password reset: %v", err) + return ErrServiceUnavailable + } + + // Check if user is active + if !user.IsActive() { + return ErrUserNotActive + } + + // Hash new password + hashedPassword, err := s.HashPassword(newPassword) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + + // Update password and increment TokenVersion + user.PasswordHash = hashedPassword + user.TokenVersion++ // Invalidate all existing tokens + + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err) + return ErrServiceUnavailable + } + + log.Printf("[Auth] Password reset successful for user: %s", email) + return nil +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index bc8f6f68..e31ca561 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin return nil } +func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) { + return nil, nil +} + +func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error { + return nil +} + +func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error { + return nil +} + +func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool { + return false +} + +func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error { + return nil +} + func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 3bb63ffa..31a34e00 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -69,9 +69,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" // Setting keys const ( // 注册设置 - SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 - SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 - SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 + SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册 + SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证 + SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能 + SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证) // 邮件服务设置 SettingKeySMTPHost = "smtp_host" // SMTP服务器地址 @@ -87,6 +88,9 @@ const ( SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key + // TOTP 双因素认证设置 + SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能 + // LinuxDo Connect OAuth 登录设置 SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled" SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id" diff --git a/backend/internal/service/email_queue_service.go b/backend/internal/service/email_queue_service.go index 1c22702c..6c975c69 100644 --- a/backend/internal/service/email_queue_service.go +++ b/backend/internal/service/email_queue_service.go @@ -8,11 +8,18 @@ import ( "time" ) +// Task type constants +const ( + TaskTypeVerifyCode = "verify_code" + TaskTypePasswordReset = "password_reset" +) + // EmailTask 邮件发送任务 type EmailTask struct { Email string SiteName string - TaskType string // "verify_code" + TaskType string // "verify_code" or "password_reset" + ResetURL string // Only used for password_reset task type } // EmailQueueService 异步邮件队列服务 @@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { defer cancel() switch task.TaskType { - case "verify_code": + case TaskTypeVerifyCode: if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) } else { log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) } + case TaskTypePasswordReset: + if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil { + log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) + } else { + log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) + } default: log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) } @@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { task := EmailTask{ Email: email, SiteName: siteName, - TaskType: "verify_code", + TaskType: TaskTypeVerifyCode, } select { @@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { } } +// EnqueuePasswordReset 将密码重置邮件任务加入队列 +func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error { + task := EmailTask{ + Email: email, + SiteName: siteName, + TaskType: TaskTypePasswordReset, + ResetURL: resetURL, + } + + select { + case s.taskChan <- task: + log.Printf("[EmailQueue] Enqueued password reset task for %s", email) + return nil + default: + return fmt.Errorf("email queue is full") + } +} + // Stop 停止队列服务 func (s *EmailQueueService) Stop() { close(s.stopChan) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 55e137d6..44edf7f7 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -3,11 +3,14 @@ package service import ( "context" "crypto/rand" + "crypto/subtle" "crypto/tls" + "encoding/hex" "fmt" "log" "math/big" "net/smtp" + "net/url" "strconv" "time" @@ -19,6 +22,9 @@ var ( ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code") ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code") ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code") + + // Password reset errors + ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token") ) // EmailCache defines cache operations for email service @@ -26,6 +32,16 @@ type EmailCache interface { GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error DeleteVerificationCode(ctx context.Context, email string) error + + // Password reset token methods + GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) + SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error + DeletePasswordResetToken(ctx context.Context, email string) error + + // Password reset email cooldown methods + // Returns true if in cooldown period (email was sent recently) + IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool + SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error } // VerificationCodeData represents verification code data @@ -35,10 +51,22 @@ type VerificationCodeData struct { CreatedAt time.Time } +// PasswordResetTokenData represents password reset token data +type PasswordResetTokenData struct { + Token string + CreatedAt time.Time +} + const ( verifyCodeTTL = 15 * time.Minute verifyCodeCooldown = 1 * time.Minute maxVerifyCodeAttempts = 5 + + // Password reset token settings + passwordResetTokenTTL = 30 * time.Minute + + // Password reset email cooldown (prevent email bombing) + passwordResetEmailCooldown = 30 * time.Second ) // SMTPConfig SMTP配置 @@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error return ErrVerifyCodeMaxAttempts } - // 验证码不匹配 - if data.Code != code { + // 验证码不匹配 (constant-time comparison to prevent timing attacks) + if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { log.Printf("[Email] Failed to update verification attempt count: %v", err) @@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error { return client.Quit() } + +// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters) +func (s *EmailService) GeneratePasswordResetToken() (string, error) { + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return hex.EncodeToString(bytes), nil +} + +// SendPasswordResetEmail sends a password reset email with a reset link +func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error { + var token string + var needSaveToken bool + + // Check if token already exists + existing, err := s.cache.GetPasswordResetToken(ctx, email) + if err == nil && existing != nil { + // Token exists, reuse it (allows resending email without generating new token) + token = existing.Token + needSaveToken = false + } else { + // Generate new token + token, err = s.GeneratePasswordResetToken() + if err != nil { + return fmt.Errorf("generate token: %w", err) + } + needSaveToken = true + } + + // Save token to Redis (only if new token generated) + if needSaveToken { + data := &PasswordResetTokenData{ + Token: token, + CreatedAt: time.Now(), + } + if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil { + return fmt.Errorf("save reset token: %w", err) + } + } + + // Build full reset URL with URL-encoded token and email + fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token)) + + // Build email content + subject := fmt.Sprintf("[%s] 密码重置请求", siteName) + body := s.buildPasswordResetEmailBody(fullResetURL, siteName) + + // Send email + if err := s.SendEmail(ctx, email, subject, body); err != nil { + return fmt.Errorf("send email: %w", err) + } + + return nil +} + +// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker) +// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing +func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { + // Check email cooldown to prevent email bombing + if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { + log.Printf("[Email] Password reset email skipped (cooldown): %s", email) + return nil // Silent success to prevent revealing cooldown to attackers + } + + // Send email using core method + if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { + return err + } + + // Set cooldown marker (Redis TTL handles expiration) + if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil { + log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err) + } + + return nil +} + +// VerifyPasswordResetToken verifies the password reset token without consuming it +func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error { + data, err := s.cache.GetPasswordResetToken(ctx, email) + if err != nil || data == nil { + return ErrInvalidResetToken + } + + // Use constant-time comparison to prevent timing attacks + if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 { + return ErrInvalidResetToken + } + + return nil +} + +// ConsumePasswordResetToken verifies and deletes the token (one-time use) +func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error { + // Verify first + if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil { + return err + } + + // Delete after verification (one-time use) + if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil { + log.Printf("[Email] Failed to delete password reset token after consumption: %v", err) + } + return nil +} + +// buildPasswordResetEmailBody builds the HTML content for password reset email +func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string { + return fmt.Sprintf(` + + + + + + + +
+
+

%s

+
+
+

密码重置请求

+

您已请求重置密码。请点击下方按钮设置新密码:

+ 重置密码 +
+

此链接将在 30 分钟后失效。

+

如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。

+
+ +
+ +
+ + +`, siteName, resetURL, resetURL) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9565da29..a8f5baeb 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -305,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL) } +// GetCachedSessionAccountID retrieves the account ID bound to a sticky session. +// Returns 0 if no binding exists or on error. +func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if sessionHash == "" || s.cache == nil { + return 0, nil + } + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err != nil { + return 0, err + } + return accountID, nil +} + func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" diff --git a/backend/internal/service/gemini_native_signature_cleaner.go b/backend/internal/service/gemini_native_signature_cleaner.go new file mode 100644 index 00000000..b3352fb0 --- /dev/null +++ b/backend/internal/service/gemini_native_signature_cleaner.go @@ -0,0 +1,72 @@ +package service + +import ( + "encoding/json" +) + +// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段, +// 以避免跨账号签名验证错误。 +// +// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature +// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。 +// +// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests +// to avoid cross-account signature validation errors. +// +// When sticky session switches accounts (e.g., original account becomes unavailable), +// thoughtSignatures from the old account will cause validation failures on the new account. +// By removing these signatures, we allow the new account to generate valid signatures. +func CleanGeminiNativeThoughtSignatures(body []byte) []byte { + if len(body) == 0 { + return body + } + + // 解析 JSON + var data any + if err := json.Unmarshal(body, &data); err != nil { + // 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确) + return body + } + + // 递归清理 thoughtSignature + cleaned := cleanThoughtSignaturesRecursive(data) + + // 重新序列化 + result, err := json.Marshal(cleaned) + if err != nil { + // 如果序列化失败,返回原始 body + return body + } + + return result +} + +// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段 +func cleanThoughtSignaturesRecursive(data any) any { + switch v := data.(type) { + case map[string]any: + // 创建新的 map,移除 thoughtSignature + result := make(map[string]any, len(v)) + for key, value := range v { + // 跳过 thoughtSignature 字段 + if key == "thoughtSignature" { + continue + } + // 递归处理嵌套结构 + result[key] = cleanThoughtSignaturesRecursive(value) + } + return result + + case []any: + // 递归处理数组中的每个元素 + result := make([]any, len(v)) + for i, item := range v { + result[i] = cleanThoughtSignaturesRecursive(item) + } + return result + + default: + // 基本类型(string, number, bool, null)直接返回 + return v + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 65ba01b3..289a13af 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct { UpdatedAt string `json:"updated_at,omitempty"` } +// NormalizedCodexLimits contains normalized 5h/7d rate limit data +type NormalizedCodexLimits struct { + Used5hPercent *float64 + Reset5hSeconds *int + Window5hMinutes *int + Used7dPercent *float64 + Reset7dSeconds *int + Window7dMinutes *int +} + +// Normalize converts primary/secondary fields to canonical 5h/7d fields. +// Strategy: Compare window_minutes to determine which is 5h vs 7d. +// Returns nil if snapshot is nil or has no useful data. +func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits { + if s == nil { + return nil + } + + result := &NormalizedCodexLimits{} + + primaryMins := 0 + secondaryMins := 0 + hasPrimaryWindow := false + hasSecondaryWindow := false + + if s.PrimaryWindowMinutes != nil { + primaryMins = *s.PrimaryWindowMinutes + hasPrimaryWindow = true + } + if s.SecondaryWindowMinutes != nil { + secondaryMins = *s.SecondaryWindowMinutes + hasSecondaryWindow = true + } + + // Determine mapping based on window_minutes + use5hFromPrimary := false + use7dFromPrimary := false + + if hasPrimaryWindow && hasSecondaryWindow { + // Both known: smaller window is 5h, larger is 7d + if primaryMins < secondaryMins { + use5hFromPrimary = true + } else { + use7dFromPrimary = true + } + } else if hasPrimaryWindow { + // Only primary known: classify by threshold (<=360 min = 6h -> 5h window) + if primaryMins <= 360 { + use5hFromPrimary = true + } else { + use7dFromPrimary = true + } + } else if hasSecondaryWindow { + // Only secondary known: classify by threshold + if secondaryMins <= 360 { + // 5h from secondary, so primary (if any data) is 7d + use7dFromPrimary = true + } else { + // 7d from secondary, so primary (if any data) is 5h + use5hFromPrimary = true + } + } else { + // No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h) + use7dFromPrimary = true + } + + // Assign values + if use5hFromPrimary { + result.Used5hPercent = s.PrimaryUsedPercent + result.Reset5hSeconds = s.PrimaryResetAfterSeconds + result.Window5hMinutes = s.PrimaryWindowMinutes + result.Used7dPercent = s.SecondaryUsedPercent + result.Reset7dSeconds = s.SecondaryResetAfterSeconds + result.Window7dMinutes = s.SecondaryWindowMinutes + } else if use7dFromPrimary { + result.Used7dPercent = s.PrimaryUsedPercent + result.Reset7dSeconds = s.PrimaryResetAfterSeconds + result.Window7dMinutes = s.PrimaryWindowMinutes + result.Used5hPercent = s.SecondaryUsedPercent + result.Reset5hSeconds = s.SecondaryResetAfterSeconds + result.Window5hMinutes = s.SecondaryWindowMinutes + } + + return result +} + // OpenAIUsage represents OpenAI API response usage type OpenAIUsage struct { InputTokens int `json:"input_tokens"` @@ -867,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Extract and save Codex usage snapshot from response headers (for OAuth accounts) if account.Type == AccountTypeOAuth { - if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) } } @@ -1665,8 +1751,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec return nil } -// extractCodexUsageHeaders extracts Codex usage limits from response headers -func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot { +// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. +// Exported for use in ratelimit_service when handling OpenAI 429 responses. +func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { snapshot := &OpenAICodexUsageSnapshot{} hasData := false @@ -1740,6 +1827,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc // Convert snapshot to map for merging into Extra updates := make(map[string]any) + + // Save raw primary/secondary fields for debugging/tracing if snapshot.PrimaryUsedPercent != nil { updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent } @@ -1763,109 +1852,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc } updates["codex_usage_updated_at"] = snapshot.UpdatedAt - // Normalize to canonical 5h/7d fields based on window_minutes - // This fixes the issue where OpenAI's primary/secondary naming is reversed - // Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d - - // IMPORTANT: We can only reliably determine window type from window_minutes field - // The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison - - var primaryWindowMins, secondaryWindowMins int - var hasPrimaryWindow, hasSecondaryWindow bool - - // Only use window_minutes for reliable window size comparison - if snapshot.PrimaryWindowMinutes != nil { - primaryWindowMins = *snapshot.PrimaryWindowMinutes - hasPrimaryWindow = true - } - - if snapshot.SecondaryWindowMinutes != nil { - secondaryWindowMins = *snapshot.SecondaryWindowMinutes - hasSecondaryWindow = true - } - - // Determine which is 5h and which is 7d - var use5hFromPrimary, use7dFromPrimary bool - var use5hFromSecondary, use7dFromSecondary bool - - if hasPrimaryWindow && hasSecondaryWindow { - // Both window sizes known: compare and assign smaller to 5h, larger to 7d - if primaryWindowMins < secondaryWindowMins { - use5hFromPrimary = true - use7dFromSecondary = true - } else { - use5hFromSecondary = true - use7dFromPrimary = true + // Normalize to canonical 5h/7d fields + if normalized := snapshot.Normalize(); normalized != nil { + if normalized.Used5hPercent != nil { + updates["codex_5h_used_percent"] = *normalized.Used5hPercent } - } else if hasPrimaryWindow { - // Only primary window size known: classify by absolute threshold - if primaryWindowMins <= 360 { - use5hFromPrimary = true - } else { - use7dFromPrimary = true + if normalized.Reset5hSeconds != nil { + updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds } - } else if hasSecondaryWindow { - // Only secondary window size known: classify by absolute threshold - if secondaryWindowMins <= 360 { - use5hFromSecondary = true - } else { - use7dFromSecondary = true + if normalized.Window5hMinutes != nil { + updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes } - } else { - // No window_minutes available: cannot reliably determine window types - // Fall back to legacy assumption (may be incorrect) - // Assume primary=7d, secondary=5h based on historical observation - if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil { - use5hFromSecondary = true + if normalized.Used7dPercent != nil { + updates["codex_7d_used_percent"] = *normalized.Used7dPercent } - if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil { - use7dFromPrimary = true + if normalized.Reset7dSeconds != nil { + updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds } - } - - // Write canonical 5h fields - if use5hFromPrimary { - if snapshot.PrimaryUsedPercent != nil { - updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent - } - if snapshot.PrimaryResetAfterSeconds != nil { - updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds - } - if snapshot.PrimaryWindowMinutes != nil { - updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes - } - } else if use5hFromSecondary { - if snapshot.SecondaryUsedPercent != nil { - updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent - } - if snapshot.SecondaryResetAfterSeconds != nil { - updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds - } - if snapshot.SecondaryWindowMinutes != nil { - updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes - } - } - - // Write canonical 7d fields - if use7dFromPrimary { - if snapshot.PrimaryUsedPercent != nil { - updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent - } - if snapshot.PrimaryResetAfterSeconds != nil { - updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds - } - if snapshot.PrimaryWindowMinutes != nil { - updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes - } - } else if use7dFromSecondary { - if snapshot.SecondaryUsedPercent != nil { - updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent - } - if snapshot.SecondaryResetAfterSeconds != nil { - updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds - } - if snapshot.SecondaryWindowMinutes != nil { - updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes + if normalized.Window7dMinutes != nil { + updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes } } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 41bd253c..6b7ebb07 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A // handle429 处理429限流错误 // 解析响应头获取重置时间,标记账号为限流状态 func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { - // 解析重置时间戳 + // 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded) + if account.Platform == PlatformOpenAI { + if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt) + return + } + } + + // 2. 尝试从响应头解析重置时间(Anthropic) resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset") + + // 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini) if resetTimestamp == "" { + switch account.Platform { + case PlatformOpenAI: + // 尝试解析 OpenAI 的 usage_limit_reached 错误 + if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil { + resetTime := time.Unix(*resetAt, 0) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second)) + return + } + case PlatformGemini, PlatformAntigravity: + // 尝试解析 Gemini 格式(用于其他平台) + if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil { + resetTime := time.Unix(*resetAt, 0) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { + slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) + return + } + slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second)) + return + } + } + // 没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) { @@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head } return } + slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } @@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re return strings.Contains(msg, "sonnet") } +// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间 +// 返回 nil 表示无法从响应头中确定重置时间 +func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time { + snapshot := ParseCodexRateLimitHeaders(headers) + if snapshot == nil { + return nil + } + + normalized := snapshot.Normalize() + if normalized == nil { + return nil + } + + now := time.Now() + + // 判断哪个限制被触发(used_percent >= 100) + is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100 + is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100 + + // 优先使用被触发限制的重置时间 + if is7dExhausted && normalized.Reset7dSeconds != nil { + resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second) + slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt) + return &resetAt + } + if is5hExhausted && normalized.Reset5hSeconds != nil { + resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second) + slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt) + return &resetAt + } + + // 都未达到100%但收到429,使用较长的重置时间 + var maxResetSecs int + if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs { + maxResetSecs = *normalized.Reset7dSeconds + } + if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs { + maxResetSecs = *normalized.Reset5hSeconds + } + if maxResetSecs > 0 { + resetAt := now.Add(time.Duration(maxResetSecs) * time.Second) + slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt) + return &resetAt + } + + return nil +} + +// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳 +// OpenAI 的 usage_limit_reached 错误格式: +// +// { +// "error": { +// "message": "The usage limit has been reached", +// "type": "usage_limit_reached", +// "resets_at": 1769404154, +// "resets_in_seconds": 133107 +// } +// } +func parseOpenAIRateLimitResetTime(body []byte) *int64 { + var parsed map[string]any + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + + errObj, ok := parsed["error"].(map[string]any) + if !ok { + return nil + } + + // 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型 + errType, _ := errObj["type"].(string) + if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" { + return nil + } + + // 优先使用 resets_at(Unix 时间戳) + if resetsAt, ok := errObj["resets_at"].(float64); ok { + ts := int64(resetsAt) + return &ts + } + if resetsAt, ok := errObj["resets_at"].(string); ok { + if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil { + return &ts + } + } + + // 如果没有 resets_at,尝试使用 resets_in_seconds + if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok { + ts := time.Now().Unix() + int64(resetsInSeconds) + return &ts + } + if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok { + if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil { + ts := time.Now().Unix() + sec + return &ts + } + } + + return nil +} + // handle529 处理529过载错误 // 根据配置设置过载冷却时间 func (s *RateLimitService) handle529(ctx context.Context, account *Account) { diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go new file mode 100644 index 00000000..00902068 --- /dev/null +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -0,0 +1,364 @@ +package service + +import ( + "net/http" + "testing" + "time" +) + +func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) { + svc := &RateLimitService{} + + // Simulate headers when 7d limit is exhausted (100% used) + // Primary = 7d (10080 minutes), Secondary = 5h (300 minutes) + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days + headers.Set("x-codex-primary-window-minutes", "10080") // 7 days + headers.Set("x-codex-secondary-used-percent", "3") + headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours + headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours + + before := time.Now() + resetAt := svc.calculateOpenAI429ResetTime(headers) + after := time.Now() + + if resetAt == nil { + t.Fatal("expected non-nil resetAt") + } + + // Should be approximately 384607 seconds from now + expectedDuration := 384607 * time.Second + minExpected := before.Add(expectedDuration) + maxExpected := after.Add(expectedDuration) + + if resetAt.Before(minExpected) || resetAt.After(maxExpected) { + t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected) + } +} + +func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) { + svc := &RateLimitService{} + + // Simulate headers when 5h limit is exhausted (100% used) + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "50") + headers.Set("x-codex-primary-reset-after-seconds", "500000") + headers.Set("x-codex-primary-window-minutes", "10080") // 7 days + headers.Set("x-codex-secondary-used-percent", "100") + headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour + headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours + + before := time.Now() + resetAt := svc.calculateOpenAI429ResetTime(headers) + after := time.Now() + + if resetAt == nil { + t.Fatal("expected non-nil resetAt") + } + + // Should be approximately 3600 seconds from now + expectedDuration := 3600 * time.Second + minExpected := before.Add(expectedDuration) + maxExpected := after.Add(expectedDuration) + + if resetAt.Before(minExpected) || resetAt.After(maxExpected) { + t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected) + } +} + +func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) { + svc := &RateLimitService{} + + // Neither limit at 100%, should use the longer reset time + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "80") + headers.Set("x-codex-primary-reset-after-seconds", "100000") + headers.Set("x-codex-primary-window-minutes", "10080") + headers.Set("x-codex-secondary-used-percent", "90") + headers.Set("x-codex-secondary-reset-after-seconds", "5000") + headers.Set("x-codex-secondary-window-minutes", "300") + + before := time.Now() + resetAt := svc.calculateOpenAI429ResetTime(headers) + after := time.Now() + + if resetAt == nil { + t.Fatal("expected non-nil resetAt") + } + + // Should use the max (100000 seconds from 7d window) + expectedDuration := 100000 * time.Second + minExpected := before.Add(expectedDuration) + maxExpected := after.Add(expectedDuration) + + if resetAt.Before(minExpected) || resetAt.After(maxExpected) { + t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected) + } +} + +func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) { + svc := &RateLimitService{} + + // No codex headers at all + headers := http.Header{} + headers.Set("content-type", "application/json") + + resetAt := svc.calculateOpenAI429ResetTime(headers) + + if resetAt != nil { + t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt) + } +} + +func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) { + svc := &RateLimitService{} + + // Test when OpenAI sends primary as 5h and secondary as 7d (reversed) + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") // This is 5h + headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour + headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller! + headers.Set("x-codex-secondary-used-percent", "50") + headers.Set("x-codex-secondary-reset-after-seconds", "500000") + headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger! + + before := time.Now() + resetAt := svc.calculateOpenAI429ResetTime(headers) + after := time.Now() + + if resetAt == nil { + t.Fatal("expected non-nil resetAt") + } + + // Should correctly identify that primary is 5h (smaller window) and use its reset time + expectedDuration := 3600 * time.Second + minExpected := before.Add(expectedDuration) + maxExpected := after.Add(expectedDuration) + + if resetAt.Before(minExpected) || resetAt.After(maxExpected) { + t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected) + } +} + +func TestNormalizedCodexLimits(t *testing.T) { + // Test the Normalize() method directly + pUsed := 100.0 + pReset := 384607 + pWindow := 10080 + sUsed := 3.0 + sReset := 17369 + sWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &pUsed, + PrimaryResetAfterSeconds: &pReset, + PrimaryWindowMinutes: &pWindow, + SecondaryUsedPercent: &sUsed, + SecondaryResetAfterSeconds: &sReset, + SecondaryWindowMinutes: &sWindow, + } + + normalized := snapshot.Normalize() + if normalized == nil { + t.Fatal("expected non-nil normalized") + } + + // Primary has larger window (10080 > 300), so primary should be 7d + if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 { + t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent) + } + if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 { + t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds) + } + if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 { + t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent) + } + if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 { + t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds) + } +} + +func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) { + // Test when only primary has data, no window_minutes + pUsed := 80.0 + pReset := 50000 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &pUsed, + PrimaryResetAfterSeconds: &pReset, + // No window_minutes, no secondary data + } + + normalized := snapshot.Normalize() + if normalized == nil { + t.Fatal("expected non-nil normalized") + } + + // Legacy assumption: primary=7d, secondary=5h + if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 { + t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent) + } + if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 { + t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds) + } + // Secondary (5h) should be nil + if normalized.Used5hPercent != nil { + t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent) + } + if normalized.Reset5hSeconds != nil { + t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds) + } +} + +func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) { + // Test when only secondary has data, no window_minutes + sUsed := 60.0 + sReset := 3000 + + snapshot := &OpenAICodexUsageSnapshot{ + SecondaryUsedPercent: &sUsed, + SecondaryResetAfterSeconds: &sReset, + // No window_minutes, no primary data + } + + normalized := snapshot.Normalize() + if normalized == nil { + t.Fatal("expected non-nil normalized") + } + + // Legacy assumption: primary=7d, secondary=5h + // So secondary goes to 5h + if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 { + t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent) + } + if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 { + t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds) + } + // Primary (7d) should be nil + if normalized.Used7dPercent != nil { + t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent) + } +} + +func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) { + // Test when both have data but no window_minutes + pUsed := 100.0 + pReset := 400000 + sUsed := 50.0 + sReset := 10000 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &pUsed, + PrimaryResetAfterSeconds: &pReset, + SecondaryUsedPercent: &sUsed, + SecondaryResetAfterSeconds: &sReset, + // No window_minutes + } + + normalized := snapshot.Normalize() + if normalized == nil { + t.Fatal("expected non-nil normalized") + } + + // Legacy assumption: primary=7d, secondary=5h + if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 { + t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent) + } + if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 { + t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds) + } + if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 { + t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent) + } + if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 { + t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds) + } +} + +func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) { + // Verify that Anthropic platform accounts still use the original logic + // This test ensures we don't break existing Claude account rate limiting + + svc := &RateLimitService{} + + // Simulate Anthropic 429 headers + headers := http.Header{} + headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp + + // For Anthropic platform, calculateOpenAI429ResetTime should return nil + // because it only handles OpenAI platform + resetAt := svc.calculateOpenAI429ResetTime(headers) + + // Should return nil since there are no x-codex-* headers + if resetAt != nil { + t.Errorf("expected nil for Anthropic headers, got %v", resetAt) + } +} + +func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) { + // This is the exact scenario from the user: + // codex_7d_used_percent: 100 + // codex_7d_reset_after_seconds: 384607 (约4.5天后重置) + // codex_5h_used_percent: 3 + // codex_5h_reset_after_seconds: 17369 (约4.8小时后重置) + + svc := &RateLimitService{} + + // Simulate headers matching user's data + // Note: We need to map the canonical 5h/7d back to primary/secondary + // Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window) + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + headers.Set("x-codex-primary-reset-after-seconds", "384607") + headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes + headers.Set("x-codex-secondary-used-percent", "3") + headers.Set("x-codex-secondary-reset-after-seconds", "17369") + headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes + + before := time.Now() + resetAt := svc.calculateOpenAI429ResetTime(headers) + after := time.Now() + + if resetAt == nil { + t.Fatal("expected non-nil resetAt for user scenario") + } + + // Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%) + expectedDuration := 384607 * time.Second + minExpected := before.Add(expectedDuration) + maxExpected := after.Add(expectedDuration) + + if resetAt.Before(minExpected) || resetAt.After(maxExpected) { + t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected) + } + + // Verify it's approximately 4.45 days (384607 seconds) + duration := resetAt.Sub(before) + actualDays := duration.Hours() / 24.0 + + // 384607 / 86400 = ~4.45 days + if actualDays < 4.4 || actualDays > 4.5 { + t.Errorf("expected ~4.45 days, got %.2f days", actualDays) + } + + t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays) +} + +func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) { + // Test that we return nil when there's used_percent but no reset_after_seconds + // This should cause the caller to use the default 5-minute fallback + + svc := &RateLimitService{} + + headers := http.Header{} + headers.Set("x-codex-primary-used-percent", "100") + // No reset_after_seconds! + + resetAt := svc.calculateOpenAI429ResetTime(headers) + + // Should return nil since there's no reset time available + if resetAt != nil { + t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt) + } +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 2d716a90..68e3ee08 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -61,6 +61,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyRegistrationEnabled, SettingKeyEmailVerifyEnabled, SettingKeyPromoCodeEnabled, + SettingKeyPasswordResetEnabled, + SettingKeyTotpEnabled, SettingKeyTurnstileEnabled, SettingKeyTurnstileSiteKey, SettingKeySiteName, @@ -86,21 +88,27 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled } + // Password reset requires email verification to be enabled + emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" + passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true" + return &PublicSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - LinuxDoOAuthEnabled: linuxDoEnabled, + RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, + PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: passwordResetEnabled, + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", + TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", + TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], + SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"), + SiteLogo: settings[SettingKeySiteLogo], + SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), + APIBaseURL: settings[SettingKeyAPIBaseURL], + ContactInfo: settings[SettingKeyContactInfo], + DocURL: settings[SettingKeyDocURL], + HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", + LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -125,37 +133,41 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version,omitempty"` }{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - PromoCodeEnabled: settings.PromoCodeEnabled, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - HomeContent: settings.HomeContent, - HideCcsImportButton: settings.HideCcsImportButton, - LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, - Version: s.version, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + PromoCodeEnabled: settings.PromoCodeEnabled, + PasswordResetEnabled: settings.PasswordResetEnabled, + TotpEnabled: settings.TotpEnabled, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + Version: s.version, }, nil } @@ -167,6 +179,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) + updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) + updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) // 邮件服务设置(只有非空才更新密码) updates[SettingKeySMTPHost] = settings.SMTPHost @@ -262,6 +276,35 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool { return value != "false" } +// IsPasswordResetEnabled 检查是否启用密码重置功能 +// 要求:必须同时开启邮件验证 +func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool { + // Password reset requires email verification to be enabled + if !s.IsEmailVerifyEnabled(ctx) { + return false + } + value, err := s.settingRepo.GetValue(ctx, SettingKeyPasswordResetEnabled) + if err != nil { + return false // 默认关闭 + } + return value == "true" +} + +// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能 +func (s *SettingService) IsTotpEnabled(ctx context.Context) bool { + value, err := s.settingRepo.GetValue(ctx, SettingKeyTotpEnabled) + if err != nil { + return false // 默认关闭 + } + return value == "true" +} + +// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置 +// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能 +func (s *SettingService) IsTotpEncryptionKeyConfigured() bool { + return s.cfg.Totp.EncryptionKeyConfigured +} + // GetSiteName 获取网站名称 func (s *SettingService) GetSiteName(ctx context.Context) string { value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName) @@ -340,10 +383,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // parseSettings 解析设置到结构体 func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { + emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" result := &SystemSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", + EmailVerifyEnabled: emailVerifyEnabled, PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 + PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", + TotpEnabled: settings[SettingKeyTotpEnabled] == "true", SMTPHost: settings[SettingKeySMTPHost], SMTPUsername: settings[SettingKeySMTPUsername], SMTPFrom: settings[SettingKeySMTPFrom], diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 919344e5..f10254e5 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -1,9 +1,11 @@ package service type SystemSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool + RegistrationEnabled bool + EmailVerifyEnabled bool + PromoCodeEnabled bool + PasswordResetEnabled bool + TotpEnabled bool // TOTP 双因素认证 SMTPHost string SMTPPort int @@ -57,21 +59,23 @@ type SystemSettings struct { } type PublicSettings struct { - RegistrationEnabled bool - EmailVerifyEnabled bool - PromoCodeEnabled bool - TurnstileEnabled bool - TurnstileSiteKey string - SiteName string - SiteLogo string - SiteSubtitle string - APIBaseURL string - ContactInfo string - DocURL string - HomeContent string - HideCcsImportButton bool - LinuxDoOAuthEnabled bool - Version string + RegistrationEnabled bool + EmailVerifyEnabled bool + PromoCodeEnabled bool + PasswordResetEnabled bool + TotpEnabled bool // TOTP 双因素认证 + TurnstileEnabled bool + TurnstileSiteKey string + SiteName string + SiteLogo string + SiteSubtitle string + APIBaseURL string + ContactInfo string + DocURL string + HomeContent string + HideCcsImportButton bool + LinuxDoOAuthEnabled bool + Version string } // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) diff --git a/backend/internal/service/subscription_expiry_service.go b/backend/internal/service/subscription_expiry_service.go new file mode 100644 index 00000000..ce6b32b8 --- /dev/null +++ b/backend/internal/service/subscription_expiry_service.go @@ -0,0 +1,71 @@ +package service + +import ( + "context" + "log" + "sync" + "time" +) + +// SubscriptionExpiryService periodically updates expired subscription status. +type SubscriptionExpiryService struct { + userSubRepo UserSubscriptionRepository + interval time.Duration + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup +} + +func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService { + return &SubscriptionExpiryService{ + userSubRepo: userSubRepo, + interval: interval, + stopCh: make(chan struct{}), + } +} + +func (s *SubscriptionExpiryService) Start() { + if s == nil || s.userSubRepo == nil || s.interval <= 0 { + return + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + s.runOnce() + for { + select { + case <-ticker.C: + s.runOnce() + case <-s.stopCh: + return + } + } + }() +} + +func (s *SubscriptionExpiryService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + }) + s.wg.Wait() +} + +func (s *SubscriptionExpiryService) runOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + updated, err := s.userSubRepo.BatchUpdateExpiredStatus(ctx) + if err != nil { + log.Printf("[SubscriptionExpiry] Update expired subscriptions failed: %v", err) + return + } + if updated > 0 { + log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated) + } +} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index c25c58a2..3c42852e 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti days = -MaxValidityDays } + now := time.Now() + isExpired := !sub.ExpiresAt.After(now) + + // 如果订阅已过期,不允许负向调整 + if isExpired && days < 0 { + return nil, infraerrors.BadRequest("CANNOT_SHORTEN_EXPIRED", "cannot shorten an expired subscription") + } + // 计算新的过期时间 - newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days) + var newExpiresAt time.Time + if isExpired { + // 已过期:从当前时间开始增加天数 + newExpiresAt = now.AddDate(0, 0, days) + } else { + // 未过期:从原过期时间增加/减少天数 + newExpiresAt = sub.ExpiresAt.AddDate(0, 0, days) + } + if newExpiresAt.After(MaxExpiresAt) { newExpiresAt = MaxExpiresAt } - // 如果是缩短(负数),检查新的过期时间必须大于当前时间 - if days < 0 { - now := time.Now() - if !newExpiresAt.After(now) { - return nil, ErrAdjustWouldExpire - } + // 检查新的过期时间必须大于当前时间 + if !newExpiresAt.After(now) { + return nil, ErrAdjustWouldExpire } if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { @@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID return nil, err } normalizeExpiredWindows(subs) + normalizeSubscriptionStatus(subs) return subs, nil } @@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI return nil, nil, err } normalizeExpiredWindows(subs) + normalizeSubscriptionStatus(subs) return subs, pag, nil } -// List 获取所有订阅(分页,支持筛选) -func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) { +// List 获取所有订阅(分页,支持筛选和排序) +func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status) + subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) if err != nil { return nil, nil, err } normalizeExpiredWindows(subs) + normalizeSubscriptionStatus(subs) return subs, pag, nil } @@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) { } } +// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库) +// 这确保前端显示正确的状态,即使定时任务尚未更新数据库 +func normalizeSubscriptionStatus(subs []UserSubscription) { + now := time.Now() + for i := range subs { + sub := &subs[i] + if sub.Status == SubscriptionStatusActive && !sub.ExpiresAt.After(now) { + sub.Status = SubscriptionStatusExpired + } + } +} + // startOfDay 返回给定时间所在日期的零点(保持原时区) func startOfDay(t time.Time) time.Time { return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) @@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte return progresses, nil } -// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用) -func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) { - return s.userSubRepo.BatchUpdateExpiredStatus(ctx) -} - // ValidateSubscription 验证订阅是否有效 func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error { if sub.Status == SubscriptionStatusExpired { diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 7364bd33..6ef92bbf 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -237,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc } // isNonRetryableRefreshError 判断是否为不可重试的刷新错误 -// 这些错误通常表示凭证已失效,需要用户重新授权 +// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权 +// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误 func isNonRetryableRefreshError(err error) bool { if err == nil { return false diff --git a/backend/internal/service/totp_service.go b/backend/internal/service/totp_service.go new file mode 100644 index 00000000..5192fe3d --- /dev/null +++ b/backend/internal/service/totp_service.go @@ -0,0 +1,506 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "fmt" + "log/slog" + "time" + + "github.com/pquerna/otp/totp" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +var ( + ErrTotpNotEnabled = infraerrors.BadRequest("TOTP_NOT_ENABLED", "totp feature is not enabled") + ErrTotpAlreadyEnabled = infraerrors.BadRequest("TOTP_ALREADY_ENABLED", "totp is already enabled for this account") + ErrTotpNotSetup = infraerrors.BadRequest("TOTP_NOT_SETUP", "totp is not set up for this account") + ErrTotpInvalidCode = infraerrors.BadRequest("TOTP_INVALID_CODE", "invalid totp code") + ErrTotpSetupExpired = infraerrors.BadRequest("TOTP_SETUP_EXPIRED", "totp setup session expired") + ErrTotpTooManyAttempts = infraerrors.TooManyRequests("TOTP_TOO_MANY_ATTEMPTS", "too many verification attempts, please try again later") + ErrVerifyCodeRequired = infraerrors.BadRequest("VERIFY_CODE_REQUIRED", "email verification code is required") + ErrPasswordRequired = infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required") +) + +// TotpCache defines cache operations for TOTP service +type TotpCache interface { + // Setup session methods + GetSetupSession(ctx context.Context, userID int64) (*TotpSetupSession, error) + SetSetupSession(ctx context.Context, userID int64, session *TotpSetupSession, ttl time.Duration) error + DeleteSetupSession(ctx context.Context, userID int64) error + + // Login session methods (for 2FA login flow) + GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error) + SetLoginSession(ctx context.Context, tempToken string, session *TotpLoginSession, ttl time.Duration) error + DeleteLoginSession(ctx context.Context, tempToken string) error + + // Rate limiting + IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) + GetVerifyAttempts(ctx context.Context, userID int64) (int, error) + ClearVerifyAttempts(ctx context.Context, userID int64) error +} + +// SecretEncryptor defines encryption operations for TOTP secrets +type SecretEncryptor interface { + Encrypt(plaintext string) (string, error) + Decrypt(ciphertext string) (string, error) +} + +// TotpSetupSession represents a TOTP setup session +type TotpSetupSession struct { + Secret string // Plain text TOTP secret (not encrypted yet) + SetupToken string // Random token to verify setup request + CreatedAt time.Time +} + +// TotpLoginSession represents a pending 2FA login session +type TotpLoginSession struct { + UserID int64 + Email string + TokenExpiry time.Time +} + +// TotpStatus represents the TOTP status for a user +type TotpStatus struct { + Enabled bool `json:"enabled"` + EnabledAt *time.Time `json:"enabled_at,omitempty"` + FeatureEnabled bool `json:"feature_enabled"` +} + +// TotpSetupResponse represents the response for initiating TOTP setup +type TotpSetupResponse struct { + Secret string `json:"secret"` + QRCodeURL string `json:"qr_code_url"` + SetupToken string `json:"setup_token"` + Countdown int `json:"countdown"` // seconds until setup expires +} + +const ( + totpSetupTTL = 5 * time.Minute + totpLoginTTL = 5 * time.Minute + totpAttemptsTTL = 15 * time.Minute + maxTotpAttempts = 5 + totpIssuer = "Sub2API" +) + +// TotpService handles TOTP operations +type TotpService struct { + userRepo UserRepository + encryptor SecretEncryptor + cache TotpCache + settingService *SettingService + emailService *EmailService + emailQueueService *EmailQueueService +} + +// NewTotpService creates a new TOTP service +func NewTotpService( + userRepo UserRepository, + encryptor SecretEncryptor, + cache TotpCache, + settingService *SettingService, + emailService *EmailService, + emailQueueService *EmailQueueService, +) *TotpService { + return &TotpService{ + userRepo: userRepo, + encryptor: encryptor, + cache: cache, + settingService: settingService, + emailService: emailService, + emailQueueService: emailQueueService, + } +} + +// GetStatus returns the TOTP status for a user +func (s *TotpService) GetStatus(ctx context.Context, userID int64) (*TotpStatus, error) { + featureEnabled := s.settingService.IsTotpEnabled(ctx) + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + return &TotpStatus{ + Enabled: user.TotpEnabled, + EnabledAt: user.TotpEnabledAt, + FeatureEnabled: featureEnabled, + }, nil +} + +// InitiateSetup starts the TOTP setup process +// If email verification is enabled, emailCode is required; otherwise password is required +func (s *TotpService) InitiateSetup(ctx context.Context, userID int64, emailCode, password string) (*TotpSetupResponse, error) { + // Check if TOTP feature is enabled globally + if !s.settingService.IsTotpEnabled(ctx) { + return nil, ErrTotpNotEnabled + } + + // Get user and check if TOTP is already enabled + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + if user.TotpEnabled { + return nil, ErrTotpAlreadyEnabled + } + + // Verify identity based on email verification setting + if s.settingService.IsEmailVerifyEnabled(ctx) { + // Email verification enabled - verify email code + if emailCode == "" { + return nil, ErrVerifyCodeRequired + } + if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil { + return nil, err + } + } else { + // Email verification disabled - verify password + if password == "" { + return nil, ErrPasswordRequired + } + if !user.CheckPassword(password) { + return nil, ErrPasswordIncorrect + } + } + + // Generate a new TOTP key + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: totpIssuer, + AccountName: user.Email, + }) + if err != nil { + return nil, fmt.Errorf("generate totp key: %w", err) + } + + // Generate a random setup token + setupToken, err := generateRandomToken(32) + if err != nil { + return nil, fmt.Errorf("generate setup token: %w", err) + } + + // Store the setup session in cache + session := &TotpSetupSession{ + Secret: key.Secret(), + SetupToken: setupToken, + CreatedAt: time.Now(), + } + + if err := s.cache.SetSetupSession(ctx, userID, session, totpSetupTTL); err != nil { + return nil, fmt.Errorf("store setup session: %w", err) + } + + return &TotpSetupResponse{ + Secret: key.Secret(), + QRCodeURL: key.URL(), + SetupToken: setupToken, + Countdown: int(totpSetupTTL.Seconds()), + }, nil +} + +// CompleteSetup completes the TOTP setup by verifying the code +func (s *TotpService) CompleteSetup(ctx context.Context, userID int64, totpCode, setupToken string) error { + // Check if TOTP feature is enabled globally + if !s.settingService.IsTotpEnabled(ctx) { + return ErrTotpNotEnabled + } + + // Get the setup session + session, err := s.cache.GetSetupSession(ctx, userID) + if err != nil { + return ErrTotpSetupExpired + } + + if session == nil { + return ErrTotpSetupExpired + } + + // Verify the setup token (constant-time comparison) + if subtle.ConstantTimeCompare([]byte(session.SetupToken), []byte(setupToken)) != 1 { + return ErrTotpSetupExpired + } + + // Verify the TOTP code + if !totp.Validate(totpCode, session.Secret) { + return ErrTotpInvalidCode + } + + setupSecretPrefix := "N/A" + if len(session.Secret) >= 4 { + setupSecretPrefix = session.Secret[:4] + } + slog.Debug("totp_complete_setup_before_encrypt", + "user_id", userID, + "secret_len", len(session.Secret), + "secret_prefix", setupSecretPrefix) + + // Encrypt the secret + encryptedSecret, err := s.encryptor.Encrypt(session.Secret) + if err != nil { + return fmt.Errorf("encrypt totp secret: %w", err) + } + + slog.Debug("totp_complete_setup_encrypted", + "user_id", userID, + "encrypted_len", len(encryptedSecret)) + + // Verify encryption by decrypting + decrypted, decErr := s.encryptor.Decrypt(encryptedSecret) + if decErr != nil { + slog.Debug("totp_complete_setup_verify_failed", + "user_id", userID, + "error", decErr) + } else { + decryptedPrefix := "N/A" + if len(decrypted) >= 4 { + decryptedPrefix = decrypted[:4] + } + slog.Debug("totp_complete_setup_verified", + "user_id", userID, + "original_len", len(session.Secret), + "decrypted_len", len(decrypted), + "match", session.Secret == decrypted, + "decrypted_prefix", decryptedPrefix) + } + + // Update user with encrypted TOTP secret + if err := s.userRepo.UpdateTotpSecret(ctx, userID, &encryptedSecret); err != nil { + return fmt.Errorf("update totp secret: %w", err) + } + + // Enable TOTP for the user + if err := s.userRepo.EnableTotp(ctx, userID); err != nil { + return fmt.Errorf("enable totp: %w", err) + } + + // Clean up the setup session + _ = s.cache.DeleteSetupSession(ctx, userID) + + return nil +} + +// Disable disables TOTP for a user +// If email verification is enabled, emailCode is required; otherwise password is required +func (s *TotpService) Disable(ctx context.Context, userID int64, emailCode, password string) error { + // Get user + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user: %w", err) + } + + if !user.TotpEnabled { + return ErrTotpNotSetup + } + + // Verify identity based on email verification setting + if s.settingService.IsEmailVerifyEnabled(ctx) { + // Email verification enabled - verify email code + if emailCode == "" { + return ErrVerifyCodeRequired + } + if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil { + return err + } + } else { + // Email verification disabled - verify password + if password == "" { + return ErrPasswordRequired + } + if !user.CheckPassword(password) { + return ErrPasswordIncorrect + } + } + + // Disable TOTP + if err := s.userRepo.DisableTotp(ctx, userID); err != nil { + return fmt.Errorf("disable totp: %w", err) + } + + return nil +} + +// VerifyCode verifies a TOTP code for a user +func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) error { + slog.Debug("totp_verify_code_called", + "user_id", userID, + "code_len", len(code)) + + // Check rate limiting + attempts, err := s.cache.GetVerifyAttempts(ctx, userID) + if err == nil && attempts >= maxTotpAttempts { + return ErrTotpTooManyAttempts + } + + // Get user + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + slog.Debug("totp_verify_get_user_failed", + "user_id", userID, + "error", err) + return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code") + } + + if !user.TotpEnabled || user.TotpSecretEncrypted == nil { + slog.Debug("totp_verify_not_setup", + "user_id", userID, + "enabled", user.TotpEnabled, + "has_secret", user.TotpSecretEncrypted != nil) + return ErrTotpNotSetup + } + + slog.Debug("totp_verify_encrypted_secret", + "user_id", userID, + "encrypted_len", len(*user.TotpSecretEncrypted)) + + // Decrypt the secret + secret, err := s.encryptor.Decrypt(*user.TotpSecretEncrypted) + if err != nil { + slog.Debug("totp_verify_decrypt_failed", + "user_id", userID, + "error", err) + return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code") + } + + secretPrefix := "N/A" + if len(secret) >= 4 { + secretPrefix = secret[:4] + } + slog.Debug("totp_verify_decrypted", + "user_id", userID, + "secret_len", len(secret), + "secret_prefix", secretPrefix) + + // Verify the code + valid := totp.Validate(code, secret) + slog.Debug("totp_verify_result", + "user_id", userID, + "valid", valid, + "secret_len", len(secret), + "secret_prefix", secretPrefix, + "server_time", time.Now().UTC().Format(time.RFC3339)) + + if !valid { + // Increment failed attempts + _, _ = s.cache.IncrementVerifyAttempts(ctx, userID) + return ErrTotpInvalidCode + } + + // Clear attempt counter on success + _ = s.cache.ClearVerifyAttempts(ctx, userID) + + return nil +} + +// CreateLoginSession creates a temporary login session for 2FA +func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) { + // Generate a random temp token + tempToken, err := generateRandomToken(32) + if err != nil { + return "", fmt.Errorf("generate temp token: %w", err) + } + + session := &TotpLoginSession{ + UserID: userID, + Email: email, + TokenExpiry: time.Now().Add(totpLoginTTL), + } + + if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil { + return "", fmt.Errorf("store login session: %w", err) + } + + return tempToken, nil +} + +// GetLoginSession retrieves a login session +func (s *TotpService) GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error) { + return s.cache.GetLoginSession(ctx, tempToken) +} + +// DeleteLoginSession deletes a login session +func (s *TotpService) DeleteLoginSession(ctx context.Context, tempToken string) error { + return s.cache.DeleteLoginSession(ctx, tempToken) +} + +// IsTotpEnabledForUser checks if TOTP is enabled for a specific user +func (s *TotpService) IsTotpEnabledForUser(ctx context.Context, userID int64) (bool, error) { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return false, fmt.Errorf("get user: %w", err) + } + return user.TotpEnabled, nil +} + +// MaskEmail masks an email address for display +func MaskEmail(email string) string { + if len(email) < 3 { + return "***" + } + + atIdx := -1 + for i, c := range email { + if c == '@' { + atIdx = i + break + } + } + + if atIdx == -1 || atIdx < 1 { + return email[:1] + "***" + } + + localPart := email[:atIdx] + domain := email[atIdx:] + + if len(localPart) <= 2 { + return localPart[:1] + "***" + domain + } + + return localPart[:1] + "***" + localPart[len(localPart)-1:] + domain +} + +// generateRandomToken generates a random hex-encoded token +func generateRandomToken(byteLength int) (string, error) { + b := make([]byte, byteLength) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} + +// VerificationMethod represents the method required for TOTP operations +type VerificationMethod struct { + Method string `json:"method"` // "email" or "password" +} + +// GetVerificationMethod returns the verification method for TOTP operations +func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMethod { + if s.settingService.IsEmailVerifyEnabled(ctx) { + return &VerificationMethod{Method: "email"} + } + return &VerificationMethod{Method: "password"} +} + +// SendVerifyCode sends an email verification code for TOTP operations +func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error { + // Check if email verification is enabled + if !s.settingService.IsEmailVerifyEnabled(ctx) { + return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled") + } + + // Get user email + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user: %w", err) + } + + // Get site name for email + siteName := s.settingService.GetSiteName(ctx) + + // Send verification code via queue + return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName) +} diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index c565607e..0f589eb3 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -21,6 +21,11 @@ type User struct { CreatedAt time.Time UpdatedAt time.Time + // TOTP 双因素认证字段 + TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 + TotpEnabled bool // 是否启用 TOTP + TotpEnabledAt *time.Time // TOTP 启用时间 + APIKeys []APIKey Subscriptions []UserSubscription } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 1734914a..99bf7fd0 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -38,6 +38,11 @@ type UserRepository interface { UpdateConcurrency(ctx context.Context, id int64, amount int) error ExistsByEmail(ctx context.Context, email string) (bool, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) + + // TOTP 相关方法 + UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error + EnableTotp(ctx context.Context, userID int64) error + DisableTotp(ctx context.Context, userID int64) error } // UpdateProfileRequest 更新用户资料请求 diff --git a/backend/internal/service/user_subscription_port.go b/backend/internal/service/user_subscription_port.go index abf4dffd..2dfc8d02 100644 --- a/backend/internal/service/user_subscription_port.go +++ b/backend/internal/service/user_subscription_port.go @@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) - List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b210286d..df86b2e7 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -72,6 +72,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe return svc } +// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService. +func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService { + svc := NewSubscriptionExpiryService(userSubRepo, time.Minute) + svc.Start() + return svc +} + // ProvideTimingWheelService creates and starts TimingWheelService func ProvideTimingWheelService() (*TimingWheelService, error) { svc, err := NewTimingWheelService() @@ -256,6 +263,7 @@ var ProviderSet = wire.NewSet( ProvideUpdateService, ProvideTokenRefreshService, ProvideAccountExpiryService, + ProvideSubscriptionExpiryService, ProvideTimingWheelService, ProvideDashboardAggregationService, ProvideUsageCleanupService, @@ -263,4 +271,5 @@ var ProviderSet = wire.NewSet( NewAntigravityQuotaFetcher, NewUserAttributeService, NewUsageCache, + NewTotpService, ) diff --git a/backend/internal/util/urlvalidator/validator.go b/backend/internal/util/urlvalidator/validator.go index 56a888b9..49df015b 100644 --- a/backend/internal/util/urlvalidator/validator.go +++ b/backend/internal/util/urlvalidator/validator.go @@ -46,7 +46,7 @@ func ValidateURLFormat(raw string, allowInsecureHTTP bool) (string, error) { } } - return trimmed, nil + return strings.TrimRight(trimmed, "/"), nil } func ValidateHTTPSURL(raw string, opts ValidationOptions) (string, error) { diff --git a/backend/internal/util/urlvalidator/validator_test.go b/backend/internal/util/urlvalidator/validator_test.go index b7f9ffed..f9745da3 100644 --- a/backend/internal/util/urlvalidator/validator_test.go +++ b/backend/internal/util/urlvalidator/validator_test.go @@ -21,4 +21,31 @@ func TestValidateURLFormat(t *testing.T) { if _, err := ValidateURLFormat("https://example.com:bad", true); err == nil { t.Fatalf("expected invalid port to fail") } + + // 验证末尾斜杠被移除 + normalized, err := ValidateURLFormat("https://example.com/", false) + if err != nil { + t.Fatalf("expected trailing slash url to pass, got %v", err) + } + if normalized != "https://example.com" { + t.Fatalf("expected trailing slash to be removed, got %s", normalized) + } + + // 验证多个末尾斜杠被移除 + normalized, err = ValidateURLFormat("https://example.com///", false) + if err != nil { + t.Fatalf("expected multiple trailing slashes to pass, got %v", err) + } + if normalized != "https://example.com" { + t.Fatalf("expected all trailing slashes to be removed, got %s", normalized) + } + + // 验证带路径的 URL 末尾斜杠被移除 + normalized, err = ValidateURLFormat("https://example.com/api/v1/", false) + if err != nil { + t.Fatalf("expected trailing slash url with path to pass, got %v", err) + } + if normalized != "https://example.com/api/v1" { + t.Fatalf("expected trailing slash to be removed from path, got %s", normalized) + } } diff --git a/backend/migrations/044_add_user_totp.sql b/backend/migrations/044_add_user_totp.sql new file mode 100644 index 00000000..6e157a68 --- /dev/null +++ b/backend/migrations/044_add_user_totp.sql @@ -0,0 +1,12 @@ +-- 为 users 表添加 TOTP 双因素认证字段 +ALTER TABLE users + ADD COLUMN IF NOT EXISTS totp_secret_encrypted TEXT DEFAULT NULL, + ADD COLUMN IF NOT EXISTS totp_enabled BOOLEAN NOT NULL DEFAULT FALSE, + ADD COLUMN IF NOT EXISTS totp_enabled_at TIMESTAMPTZ DEFAULT NULL; + +COMMENT ON COLUMN users.totp_secret_encrypted IS 'AES-256-GCM 加密的 TOTP 密钥'; +COMMENT ON COLUMN users.totp_enabled IS '是否启用 TOTP 双因素认证'; +COMMENT ON COLUMN users.totp_enabled_at IS 'TOTP 启用时间'; + +-- 创建索引以支持快速查询启用 2FA 的用户 +CREATE INDEX IF NOT EXISTS idx_users_totp_enabled ON users(totp_enabled) WHERE deleted_at IS NULL AND totp_enabled = true; diff --git a/deploy/.env.example b/deploy/.env.example index f21a3c62..1e9395a0 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -61,6 +61,18 @@ ADMIN_PASSWORD= JWT_SECRET= JWT_EXPIRE_HOUR=24 +# ----------------------------------------------------------------------------- +# TOTP (2FA) Configuration +# TOTP(双因素认证)配置 +# ----------------------------------------------------------------------------- +# IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty, a +# random key will be generated on each startup, causing all existing TOTP +# configurations to become invalid (users won't be able to login with 2FA). +# Generate a secure key: openssl rand -hex 32 +# 重要:设置固定的 TOTP 加密密钥。如果留空,每次启动将生成随机密钥, +# 导致现有的 TOTP 配置失效(用户无法使用双因素认证登录)。 +TOTP_ENCRYPTION_KEY= + # ----------------------------------------------------------------------------- # Configuration File (Optional) # ----------------------------------------------------------------------------- diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 558b8ef0..98aba8f5 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -403,6 +403,21 @@ jwt: # 令牌过期时间(小时,最大 24) expire_hour: 24 +# ============================================================================= +# TOTP (2FA) Configuration +# TOTP 双因素认证配置 +# ============================================================================= +totp: + # IMPORTANT: Set a fixed encryption key for TOTP secrets. + # 重要:设置固定的 TOTP 加密密钥。 + # If left empty, a random key will be generated on each startup, causing all + # existing TOTP configurations to become invalid (users won't be able to + # login with 2FA). + # 如果留空,每次启动将生成随机密钥,导致现有的 TOTP 配置失效(用户无法使用 + # 双因素认证登录)。 + # Generate with / 生成命令: openssl rand -hex 32 + encryption_key: "" + # ============================================================================= # LinuxDo Connect OAuth Login (SSO) # LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 5d3efe8c..123cc4fd 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -61,7 +61,25 @@ services: - JWT_SECRET=${JWT_SECRET:-} - JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24} - # Timezone + # ======================================================================= + # TOTP (2FA) Configuration + # ======================================================================= + # IMPORTANT: Set a fixed encryption key for TOTP secrets. If left empty, + # a random key will be generated on each startup, causing all existing + # TOTP configurations to become invalid (users won't be able to login + # with 2FA). + # Generate a secure key: openssl rand -hex 32 + - TOTP_ENCRYPTION_KEY=${TOTP_ENCRYPTION_KEY:-} + + # ======================================================================= + # Timezone Configuration + # This affects ALL time operations in the application: + # - Database timestamps + # - Usage statistics "today" boundary + # - Subscription expiry times + # - Log timestamps + # Common values: Asia/Shanghai, America/New_York, Europe/London, UTC + # ======================================================================= - TZ=${TZ:-Asia/Shanghai} # Gemini OAuth (可选) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index e6c6144e..5c43a6a8 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -15,6 +15,7 @@ "driver.js": "^1.4.0", "file-saver": "^2.0.5", "pinia": "^2.1.7", + "qrcode": "^1.5.4", "vue": "^3.4.0", "vue-chartjs": "^5.3.0", "vue-i18n": "^9.14.5", @@ -25,6 +26,7 @@ "@types/file-saver": "^2.0.7", "@types/mdx": "^2.0.13", "@types/node": "^20.10.5", + "@types/qrcode": "^1.5.6", "@typescript-eslint/eslint-plugin": "^7.18.0", "@typescript-eslint/parser": "^7.18.0", "@vitejs/plugin-vue": "^5.2.3", @@ -1680,6 +1682,16 @@ "integrity": "sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==", "license": "MIT" }, + "node_modules/@types/qrcode": { + "version": "1.5.6", + "resolved": "https://registry.npmmirror.com/@types/qrcode/-/qrcode-1.5.6.tgz", + "integrity": "sha512-te7NQcV2BOvdj2b1hCAHzAoMNuj65kNBMz0KBaxM6c3VGBOhU0dURQKOtH8CFNI/dsKkwlv32p26qYQTWoB5bw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/web-bluetooth": { "version": "0.0.20", "resolved": "https://registry.npmjs.org/@types/web-bluetooth/-/web-bluetooth-0.0.20.tgz", @@ -2354,7 +2366,6 @@ "version": "5.0.1", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -2364,7 +2375,6 @@ "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", - "dev": true, "license": "MIT", "dependencies": { "color-convert": "^2.0.1" @@ -2646,6 +2656,15 @@ "node": ">=6" } }, + "node_modules/camelcase": { + "version": "5.3.1", + "resolved": "https://registry.npmmirror.com/camelcase/-/camelcase-5.3.1.tgz", + "integrity": "sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/camelcase-css": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", @@ -2784,6 +2803,51 @@ "node": ">= 6" } }, + "node_modules/cliui": { + "version": "6.0.0", + "resolved": "https://registry.npmmirror.com/cliui/-/cliui-6.0.0.tgz", + "integrity": "sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ==", + "license": "ISC", + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.0", + "wrap-ansi": "^6.2.0" + } + }, + "node_modules/cliui/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmmirror.com/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/cliui/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmmirror.com/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/cliui/node_modules/wrap-ansi": { + "version": "6.2.0", + "resolved": "https://registry.npmmirror.com/wrap-ansi/-/wrap-ansi-6.2.0.tgz", + "integrity": "sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/clsx": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz", @@ -2806,7 +2870,6 @@ "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", - "dev": true, "license": "MIT", "dependencies": { "color-name": "~1.1.4" @@ -2819,7 +2882,6 @@ "version": "1.1.4", "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", - "dev": true, "license": "MIT" }, "node_modules/combined-stream": { @@ -2989,6 +3051,15 @@ } } }, + "node_modules/decamelize": { + "version": "1.2.0", + "resolved": "https://registry.npmmirror.com/decamelize/-/decamelize-1.2.0.tgz", + "integrity": "sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/decimal.js": { "version": "10.6.0", "resolved": "https://registry.npmjs.org/decimal.js/-/decimal.js-10.6.0.tgz", @@ -3029,6 +3100,12 @@ "dev": true, "license": "Apache-2.0" }, + "node_modules/dijkstrajs": { + "version": "1.0.3", + "resolved": "https://registry.npmmirror.com/dijkstrajs/-/dijkstrajs-1.0.3.tgz", + "integrity": "sha512-qiSlmBq9+BCdCA/L46dw8Uy93mloxsPSbwnm5yrKn2vMPiy8KyAskTF6zuV/j5BMsmOGZDPs7KjU+mjb670kfA==", + "license": "MIT" + }, "node_modules/dir-glob": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", @@ -3759,6 +3836,15 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmmirror.com/get-caller-file/-/get-caller-file-2.0.5.tgz", + "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "license": "ISC", + "engines": { + "node": "6.* || 8.* || >= 10.*" + } + }, "node_modules/get-intrinsic": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", @@ -4156,7 +4242,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -4883,6 +4968,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/p-try": { + "version": "2.2.0", + "resolved": "https://registry.npmmirror.com/p-try/-/p-try-2.2.0.tgz", + "integrity": "sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/package-json-from-dist": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", @@ -4957,7 +5051,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -5093,6 +5186,15 @@ "node": ">= 6" } }, + "node_modules/pngjs": { + "version": "5.0.0", + "resolved": "https://registry.npmmirror.com/pngjs/-/pngjs-5.0.0.tgz", + "integrity": "sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw==", + "license": "MIT", + "engines": { + "node": ">=10.13.0" + } + }, "node_modules/polished": { "version": "4.3.1", "resolved": "https://registry.npmjs.org/polished/-/polished-4.3.1.tgz", @@ -5313,6 +5415,23 @@ "node": ">=6" } }, + "node_modules/qrcode": { + "version": "1.5.4", + "resolved": "https://registry.npmmirror.com/qrcode/-/qrcode-1.5.4.tgz", + "integrity": "sha512-1ca71Zgiu6ORjHqFBDpnSMTR2ReToX4l1Au1VFLyVeBTFavzQnv5JxMFr3ukHVKpSrSA2MCk0lNJSykjUfz7Zg==", + "license": "MIT", + "dependencies": { + "dijkstrajs": "^1.0.1", + "pngjs": "^5.0.0", + "yargs": "^15.3.1" + }, + "bin": { + "qrcode": "bin/qrcode" + }, + "engines": { + "node": ">=10.13.0" + } + }, "node_modules/querystringify": { "version": "2.2.0", "resolved": "https://registry.npmjs.org/querystringify/-/querystringify-2.2.0.tgz", @@ -5370,6 +5489,21 @@ "node": ">=8.10.0" } }, + "node_modules/require-directory": { + "version": "2.1.1", + "resolved": "https://registry.npmmirror.com/require-directory/-/require-directory-2.1.1.tgz", + "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/require-main-filename": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/require-main-filename/-/require-main-filename-2.0.0.tgz", + "integrity": "sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg==", + "license": "ISC" + }, "node_modules/requires-port": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz", @@ -5543,6 +5677,12 @@ "node": ">=10" } }, + "node_modules/set-blocking": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/set-blocking/-/set-blocking-2.0.0.tgz", + "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==", + "license": "ISC" + }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", @@ -5714,7 +5854,6 @@ "version": "6.0.1", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "dev": true, "license": "MIT", "dependencies": { "ansi-regex": "^5.0.1" @@ -6715,6 +6854,12 @@ "node": ">= 8" } }, + "node_modules/which-module": { + "version": "2.0.1", + "resolved": "https://registry.npmmirror.com/which-module/-/which-module-2.0.1.tgz", + "integrity": "sha512-iBdZ57RDvnOR9AGBhML2vFZf7h8vmBjhoaZqODJBFWHVtKkDmKuHai3cx5PgVMrX5YDNp27AofYbAwctSS+vhQ==", + "license": "ISC" + }, "node_modules/why-is-node-running": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/why-is-node-running/-/why-is-node-running-2.3.0.tgz", @@ -6928,6 +7073,12 @@ "dev": true, "license": "MIT" }, + "node_modules/y18n": { + "version": "4.0.3", + "resolved": "https://registry.npmmirror.com/y18n/-/y18n-4.0.3.tgz", + "integrity": "sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ==", + "license": "ISC" + }, "node_modules/yaml": { "version": "1.10.2", "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", @@ -6937,6 +7088,113 @@ "node": ">= 6" } }, + "node_modules/yargs": { + "version": "15.4.1", + "resolved": "https://registry.npmmirror.com/yargs/-/yargs-15.4.1.tgz", + "integrity": "sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A==", + "license": "MIT", + "dependencies": { + "cliui": "^6.0.0", + "decamelize": "^1.2.0", + "find-up": "^4.1.0", + "get-caller-file": "^2.0.1", + "require-directory": "^2.1.1", + "require-main-filename": "^2.0.0", + "set-blocking": "^2.0.0", + "string-width": "^4.2.0", + "which-module": "^2.0.0", + "y18n": "^4.0.0", + "yargs-parser": "^18.1.2" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs-parser": { + "version": "18.1.3", + "resolved": "https://registry.npmmirror.com/yargs-parser/-/yargs-parser-18.1.3.tgz", + "integrity": "sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ==", + "license": "ISC", + "dependencies": { + "camelcase": "^5.0.0", + "decamelize": "^1.2.0" + }, + "engines": { + "node": ">=6" + } + }, + "node_modules/yargs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmmirror.com/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/yargs/node_modules/find-up": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/find-up/-/find-up-4.1.0.tgz", + "integrity": "sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==", + "license": "MIT", + "dependencies": { + "locate-path": "^5.0.0", + "path-exists": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/locate-path": { + "version": "5.0.0", + "resolved": "https://registry.npmmirror.com/locate-path/-/locate-path-5.0.0.tgz", + "integrity": "sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==", + "license": "MIT", + "dependencies": { + "p-locate": "^4.1.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/p-limit": { + "version": "2.3.0", + "resolved": "https://registry.npmmirror.com/p-limit/-/p-limit-2.3.0.tgz", + "integrity": "sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==", + "license": "MIT", + "dependencies": { + "p-try": "^2.0.0" + }, + "engines": { + "node": ">=6" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/yargs/node_modules/p-locate": { + "version": "4.1.0", + "resolved": "https://registry.npmmirror.com/p-locate/-/p-locate-4.1.0.tgz", + "integrity": "sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==", + "license": "MIT", + "dependencies": { + "p-limit": "^2.2.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/yargs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmmirror.com/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index c984cd96..8e1fdb4b 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -22,6 +22,7 @@ "driver.js": "^1.4.0", "file-saver": "^2.0.5", "pinia": "^2.1.7", + "qrcode": "^1.5.4", "vue": "^3.4.0", "vue-chartjs": "^5.3.0", "vue-i18n": "^9.14.5", @@ -32,6 +33,7 @@ "@types/file-saver": "^2.0.7", "@types/mdx": "^2.0.13", "@types/node": "^20.10.5", + "@types/qrcode": "^1.5.6", "@typescript-eslint/eslint-plugin": "^7.18.0", "@typescript-eslint/parser": "^7.18.0", "@vitejs/plugin-vue": "^5.2.3", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 1a808176..df82dcdb 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -29,6 +29,9 @@ importers: pinia: specifier: ^2.1.7 version: 2.3.1(typescript@5.6.3)(vue@3.5.26(typescript@5.6.3)) + qrcode: + specifier: ^1.5.4 + version: 1.5.4 vue: specifier: ^3.4.0 version: 3.5.26(typescript@5.6.3) @@ -54,6 +57,9 @@ importers: '@types/node': specifier: ^20.10.5 version: 20.19.27 + '@types/qrcode': + specifier: ^1.5.6 + version: 1.5.6 '@typescript-eslint/eslint-plugin': specifier: ^7.18.0 version: 7.18.0(@typescript-eslint/parser@7.18.0(eslint@8.57.1)(typescript@5.6.3))(eslint@8.57.1)(typescript@5.6.3) @@ -1239,56 +1245,67 @@ packages: resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==} cpu: [arm] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.54.0': resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==} cpu: [arm] os: [linux] + libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.54.0': resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==} cpu: [arm64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.54.0': resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==} cpu: [arm64] os: [linux] + libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.54.0': resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==} cpu: [loong64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-ppc64-gnu@4.54.0': resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==} cpu: [ppc64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-gnu@4.54.0': resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==} cpu: [riscv64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.54.0': resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==} cpu: [riscv64] os: [linux] + libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.54.0': resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==} cpu: [s390x] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.54.0': resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==} cpu: [x64] os: [linux] + libc: [glibc] '@rollup/rollup-linux-x64-musl@4.54.0': resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==} cpu: [x64] os: [linux] + libc: [musl] '@rollup/rollup-openharmony-arm64@4.54.0': resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==} @@ -1479,6 +1496,9 @@ packages: '@types/parse-json@4.0.2': resolution: {integrity: sha512-dISoDXWWQwUquiKsyZ4Ng+HX2KsPL7LyHKHQwgGFEA3IaKac4Obd+h2a/a6waisAoepJlBcx9paWqjA8/HVjCw==} + '@types/qrcode@1.5.6': + resolution: {integrity: sha512-te7NQcV2BOvdj2b1hCAHzAoMNuj65kNBMz0KBaxM6c3VGBOhU0dURQKOtH8CFNI/dsKkwlv32p26qYQTWoB5bw==} + '@types/react@19.2.7': resolution: {integrity: sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==} @@ -1832,6 +1852,10 @@ packages: resolution: {integrity: sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==} engines: {node: '>= 6'} + camelcase@5.3.1: + resolution: {integrity: sha512-L28STB170nwWS63UjtlEOE3dldQApaJXZkOI1uMFfzf3rRuPegHaHesyee+YxQ+W6SvRDQV6UrdOdRiR153wJg==} + engines: {node: '>=6'} + caniuse-lite@1.0.30001761: resolution: {integrity: sha512-JF9ptu1vP2coz98+5051jZ4PwQgd2ni8A+gYSN7EA7dPKIMf0pDlSUxhdmVOaV3/fYK5uWBkgSXJaRLr4+3A6g==} @@ -1895,6 +1919,9 @@ packages: classnames@2.5.1: resolution: {integrity: sha512-saHYOzhIQs6wy2sVxTM6bUDsQO4F50V9RQ22qBpEdCW+I+/Wmke2HOl6lS6dTpdxVhb88/I6+Hs+438c3lfUow==} + cliui@6.0.0: + resolution: {integrity: sha512-t6wbgtoCXvAzst7QgXxJYqPt0usEfbgQdftEPbLL/cvv6HPE5VgvqCuAIDR0NgU52ds6rFwqrgakNLrHEjCbrQ==} + clsx@1.2.1: resolution: {integrity: sha512-EcR6r5a8bj6pu3ycsa/E/cKVGuTgZJZdsyUYHOksG/UHIiKfjxzRxYJpyVBwYaQeOvghal9fcc4PidlgzugAQg==} engines: {node: '>=6'} @@ -2164,6 +2191,10 @@ packages: supports-color: optional: true + decamelize@1.2.0: + resolution: {integrity: sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==} + engines: {node: '>=0.10.0'} + decimal.js@10.6.0: resolution: {integrity: sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==} @@ -2198,6 +2229,9 @@ packages: didyoumean@1.2.2: resolution: {integrity: sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==} + dijkstrajs@1.0.3: + resolution: {integrity: sha512-qiSlmBq9+BCdCA/L46dw8Uy93mloxsPSbwnm5yrKn2vMPiy8KyAskTF6zuV/j5BMsmOGZDPs7KjU+mjb670kfA==} + dir-glob@3.0.1: resolution: {integrity: sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==} engines: {node: '>=8'} @@ -2424,6 +2458,10 @@ packages: find-root@1.1.0: resolution: {integrity: sha512-NKfW6bec6GfKc0SGx1e07QZY9PE99u0Bft/0rzSD5k3sO/vwkVUpDUKVm5Gpp5Ue3YfShPFTX2070tDs5kB9Ng==} + find-up@4.1.0: + resolution: {integrity: sha512-PpOwAdQ/YlXQ2vj8a3h8IipDuYRi3wceVQQGYWxNINccq40Anw7BlsEXCMbt1Zt+OLA6Fq9suIpIWD0OsnISlw==} + engines: {node: '>=8'} + find-up@5.0.0: resolution: {integrity: sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==} engines: {node: '>=10'} @@ -2488,6 +2526,10 @@ packages: function-bind@1.1.2: resolution: {integrity: sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==} + get-caller-file@2.0.5: + resolution: {integrity: sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==} + engines: {node: 6.* || 8.* || >= 10.*} + get-east-asian-width@1.4.0: resolution: {integrity: sha512-QZjmEOC+IT1uk6Rx0sX22V6uHWVwbdbxf1faPqJ1QhLdGgsRGCZoyaQBm/piRdJy/D2um6hM1UP7ZEeQ4EkP+Q==} engines: {node: '>=18'} @@ -2856,6 +2898,10 @@ packages: lit@3.3.2: resolution: {integrity: sha512-NF9zbsP79l4ao2SNrH3NkfmFgN/hBYSQo90saIVI1o5GpjAdCPVstVzO1MrLOakHoEhYkrtRjPK6Ob521aoYWQ==} + locate-path@5.0.0: + resolution: {integrity: sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g==} + engines: {node: '>=8'} + locate-path@6.0.0: resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==} engines: {node: '>=10'} @@ -3239,14 +3285,26 @@ packages: resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} engines: {node: '>= 0.8.0'} + p-limit@2.3.0: + resolution: {integrity: sha512-//88mFWSJx8lxCzwdAABTJL2MyWB12+eIY7MDL2SqLmAkeKU9qxRvWuSyTjm3FUmpBEMuFfckAIqEaVGUDxb6w==} + engines: {node: '>=6'} + p-limit@3.1.0: resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} engines: {node: '>=10'} + p-locate@4.1.0: + resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==} + engines: {node: '>=8'} + p-locate@5.0.0: resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} engines: {node: '>=10'} + p-try@2.2.0: + resolution: {integrity: sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==} + engines: {node: '>=6'} + package-json-from-dist@1.0.1: resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==} @@ -3341,6 +3399,10 @@ packages: pkg-types@1.3.1: resolution: {integrity: sha512-/Jm5M4RvtBFVkKWRu2BLUTNP8/M2a+UwuAX+ae4770q1qVGtfjG+WTCupoZixokjmHiry8uI+dlY8KXYV5HVVQ==} + pngjs@5.0.0: + resolution: {integrity: sha512-40QW5YalBNfQo5yRYmiw7Yz6TKKVr3h6970B2YE+3fQpsWcrbj1PzJgxeJ19DRQjhMbKPIuMY8rFaXc8moolVw==} + engines: {node: '>=10.13.0'} + points-on-curve@0.2.0: resolution: {integrity: sha512-0mYKnYYe9ZcqMCWhUjItv/oHjvgEsfKvnUTg8sAtnHr3GVy7rGkXCb6d5cSyqrWqL4k81b9CPg3urd+T7aop3A==} @@ -3421,6 +3483,11 @@ packages: resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==} engines: {node: '>=6'} + qrcode@1.5.4: + resolution: {integrity: sha512-1ca71Zgiu6ORjHqFBDpnSMTR2ReToX4l1Au1VFLyVeBTFavzQnv5JxMFr3ukHVKpSrSA2MCk0lNJSykjUfz7Zg==} + engines: {node: '>=10.13.0'} + hasBin: true + query-string@9.3.1: resolution: {integrity: sha512-5fBfMOcDi5SA9qj5jZhWAcTtDfKF5WFdd2uD9nVNlbxVv1baq65aALy6qofpNEGELHvisjjasxQp7BlM9gvMzw==} engines: {node: '>=18'} @@ -3664,6 +3731,13 @@ packages: remark-stringify@11.0.0: resolution: {integrity: sha512-1OSmLd3awB/t8qdoEOMazZkNsfVTeY4fTsgzcQFdXNq8ToTN4ZGwrMnlda4K6smTFKD+GRV6O48i6Z4iKgPPpw==} + require-directory@2.1.1: + resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==} + engines: {node: '>=0.10.0'} + + require-main-filename@2.0.0: + resolution: {integrity: sha512-NKN5kMDylKuldxYLSUfrbo5Tuzh4hd+2E8NPPX02mZtn1VuREQToYe/ZdlJy+J3uCpfaiGF05e7B8W0iXbQHmg==} + requires-port@1.0.0: resolution: {integrity: sha512-KigOCHcocU3XODJxsu8i/j8T9tzT4adHiecwORRQ0ZZFcp7ahwXuRU1m+yuO90C5ZUyGeGfocHDI14M3L3yDAQ==} @@ -3739,6 +3813,9 @@ packages: engines: {node: '>=10'} hasBin: true + set-blocking@2.0.0: + resolution: {integrity: sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==} + set-value@2.0.1: resolution: {integrity: sha512-JxHc1weCN68wRY0fhCoXpyK55m/XPHafOmK4UWD7m2CI14GMcFypt4w/0+NV5f/ZMby2F6S2wwA7fgynh9gWSw==} engines: {node: '>=0.10.0'} @@ -4263,6 +4340,9 @@ packages: resolution: {integrity: sha512-De72GdQZzNTUBBChsXueQUnPKDkg/5A5zp7pFDuQAj5UFoENpiACU0wlCvzpAGnTkj++ihpKwKyYewn/XNUbKw==} engines: {node: '>=18'} + which-module@2.0.1: + resolution: {integrity: sha512-iBdZ57RDvnOR9AGBhML2vFZf7h8vmBjhoaZqODJBFWHVtKkDmKuHai3cx5PgVMrX5YDNp27AofYbAwctSS+vhQ==} + which@2.0.2: resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==} engines: {node: '>= 8'} @@ -4285,6 +4365,10 @@ packages: resolution: {integrity: sha512-OELeY0Q61OXpdUfTp+oweA/vtLVg5VDOXh+3he3PNzLGG/y0oylSOC1xRVj0+l4vQ3tj/bB1HVHv1ocXkQceFA==} engines: {node: '>=0.8'} + wrap-ansi@6.2.0: + resolution: {integrity: sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA==} + engines: {node: '>=8'} + wrap-ansi@7.0.0: resolution: {integrity: sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==} engines: {node: '>=10'} @@ -4324,10 +4408,21 @@ packages: xmlchars@2.2.0: resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==} + y18n@4.0.3: + resolution: {integrity: sha512-JKhqTOwSrqNA1NY5lSztJ1GrBiUodLMmIZuLiDaMRJ+itFd+ABVE8XBjOvIWL+rSqNDC74LCSFmlb/U4UZ4hJQ==} + yaml@1.10.2: resolution: {integrity: sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==} engines: {node: '>= 6'} + yargs-parser@18.1.3: + resolution: {integrity: sha512-o50j0JeToy/4K6OZcaQmW6lyXXKhq7csREXcDwk2omFPJEwUNOVtJKvmDr9EI1fAJZUyZcRF7kxGBWmRXudrCQ==} + engines: {node: '>=6'} + + yargs@15.4.1: + resolution: {integrity: sha512-aePbxDmcYW++PaqBsJ+HYUFwCdv4LVvdnhBy78E57PIor8/OVvhMrADFFEDh8DHDFRv/O9i3lPhsENjO7QX0+A==} + engines: {node: '>=8'} + yocto-queue@0.1.0: resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} engines: {node: '>=10'} @@ -5838,6 +5933,10 @@ snapshots: '@types/parse-json@4.0.2': {} + '@types/qrcode@1.5.6': + dependencies: + '@types/node': 20.19.27 + '@types/react@19.2.7': dependencies: csstype: 3.2.3 @@ -6321,6 +6420,8 @@ snapshots: camelcase-css@2.0.1: {} + camelcase@5.3.1: {} + caniuse-lite@1.0.30001761: {} ccount@2.0.1: {} @@ -6395,6 +6496,12 @@ snapshots: classnames@2.5.1: {} + cliui@6.0.0: + dependencies: + string-width: 4.2.3 + strip-ansi: 6.0.1 + wrap-ansi: 6.2.0 + clsx@1.2.1: {} clsx@2.1.1: {} @@ -6668,6 +6775,8 @@ snapshots: dependencies: ms: 2.1.3 + decamelize@1.2.0: {} + decimal.js@10.6.0: {} decode-named-character-reference@1.2.0: @@ -6694,6 +6803,8 @@ snapshots: didyoumean@1.2.2: {} + dijkstrajs@1.0.3: {} + dir-glob@3.0.1: dependencies: path-type: 4.0.0 @@ -6978,6 +7089,11 @@ snapshots: find-root@1.1.0: {} + find-up@4.1.0: + dependencies: + locate-path: 5.0.0 + path-exists: 4.0.0 + find-up@5.0.0: dependencies: locate-path: 6.0.0 @@ -7029,6 +7145,8 @@ snapshots: function-bind@1.1.2: {} + get-caller-file@2.0.5: {} + get-east-asian-width@1.4.0: {} get-intrinsic@1.3.0: @@ -7521,6 +7639,10 @@ snapshots: lit-element: 4.2.2 lit-html: 3.3.2 + locate-path@5.0.0: + dependencies: + p-locate: 4.1.0 + locate-path@6.0.0: dependencies: p-locate: 5.0.0 @@ -8194,14 +8316,24 @@ snapshots: type-check: 0.4.0 word-wrap: 1.2.5 + p-limit@2.3.0: + dependencies: + p-try: 2.2.0 + p-limit@3.1.0: dependencies: yocto-queue: 0.1.0 + p-locate@4.1.0: + dependencies: + p-limit: 2.3.0 + p-locate@5.0.0: dependencies: p-limit: 3.1.0 + p-try@2.2.0: {} + package-json-from-dist@1.0.1: {} package-manager-detector@1.6.0: {} @@ -8284,6 +8416,8 @@ snapshots: mlly: 1.8.0 pathe: 2.0.3 + pngjs@5.0.0: {} + points-on-curve@0.2.0: {} points-on-path@0.2.1: @@ -8352,6 +8486,12 @@ snapshots: punycode@2.3.1: {} + qrcode@1.5.4: + dependencies: + dijkstrajs: 1.0.3 + pngjs: 5.0.0 + yargs: 15.4.1 + query-string@9.3.1: dependencies: decode-uri-component: 0.4.1 @@ -8703,6 +8843,10 @@ snapshots: mdast-util-to-markdown: 2.1.2 unified: 11.0.5 + require-directory@2.1.1: {} + + require-main-filename@2.0.0: {} + requires-port@1.0.0: {} reselect@5.1.1: {} @@ -8788,6 +8932,8 @@ snapshots: semver@7.7.3: {} + set-blocking@2.0.0: {} + set-value@2.0.1: dependencies: extend-shallow: 2.0.1 @@ -9298,6 +9444,8 @@ snapshots: tr46: 5.1.1 webidl-conversions: 7.0.0 + which-module@2.0.1: {} + which@2.0.2: dependencies: isexe: 2.0.0 @@ -9313,6 +9461,12 @@ snapshots: word@0.3.0: {} + wrap-ansi@6.2.0: + dependencies: + ansi-styles: 4.3.0 + string-width: 4.2.3 + strip-ansi: 6.0.1 + wrap-ansi@7.0.0: dependencies: ansi-styles: 4.3.0 @@ -9345,8 +9499,29 @@ snapshots: xmlchars@2.2.0: {} + y18n@4.0.3: {} + yaml@1.10.2: {} + yargs-parser@18.1.3: + dependencies: + camelcase: 5.3.1 + decamelize: 1.2.0 + + yargs@15.4.1: + dependencies: + cliui: 6.0.0 + decamelize: 1.2.0 + find-up: 4.1.0 + get-caller-file: 2.0.5 + require-directory: 2.1.1 + require-main-filename: 2.0.0 + set-blocking: 2.0.0 + string-width: 4.2.3 + which-module: 2.0.1 + y18n: 4.0.3 + yargs-parser: 18.1.3 + yocto-queue@0.1.0: {} zustand@3.7.2(react@19.2.3): diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 6e2ade00..10ec4d8e 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -13,6 +13,9 @@ export interface SystemSettings { registration_enabled: boolean email_verify_enabled: boolean promo_code_enabled: boolean + password_reset_enabled: boolean + totp_enabled: boolean // TOTP 双因素认证 + totp_encryption_key_configured: boolean // TOTP 加密密钥是否已配置 // Default settings default_balance: number default_concurrency: number @@ -66,6 +69,8 @@ export interface UpdateSettingsRequest { registration_enabled?: boolean email_verify_enabled?: boolean promo_code_enabled?: boolean + password_reset_enabled?: boolean + totp_enabled?: boolean // TOTP 双因素认证 default_balance?: number default_concurrency?: number site_name?: string diff --git a/frontend/src/api/admin/subscriptions.ts b/frontend/src/api/admin/subscriptions.ts index 54b448e2..9f21056f 100644 --- a/frontend/src/api/admin/subscriptions.ts +++ b/frontend/src/api/admin/subscriptions.ts @@ -17,7 +17,7 @@ import type { * List all subscriptions with pagination * @param page - Page number (default: 1) * @param pageSize - Items per page (default: 20) - * @param filters - Optional filters (status, user_id, group_id) + * @param filters - Optional filters (status, user_id, group_id, sort_by, sort_order) * @returns Paginated list of subscriptions */ export async function list( @@ -27,6 +27,8 @@ export async function list( status?: 'active' | 'expired' | 'revoked' user_id?: number group_id?: number + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index fddc23ef..bbd5ed74 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -11,9 +11,23 @@ import type { CurrentUserResponse, SendVerifyCodeRequest, SendVerifyCodeResponse, - PublicSettings + PublicSettings, + TotpLoginResponse, + TotpLogin2FARequest } from '@/types' +/** + * Login response type - can be either full auth or 2FA required + */ +export type LoginResponse = AuthResponse | TotpLoginResponse + +/** + * Type guard to check if login response requires 2FA + */ +export function isTotp2FARequired(response: LoginResponse): response is TotpLoginResponse { + return 'requires_2fa' in response && response.requires_2fa === true +} + /** * Store authentication token in localStorage */ @@ -38,11 +52,28 @@ export function clearAuthToken(): void { /** * User login - * @param credentials - Username and password + * @param credentials - Email and password + * @returns Authentication response with token and user data, or 2FA required response + */ +export async function login(credentials: LoginRequest): Promise { + const { data } = await apiClient.post('/auth/login', credentials) + + // Only store token if 2FA is not required + if (!isTotp2FARequired(data)) { + setAuthToken(data.access_token) + localStorage.setItem('auth_user', JSON.stringify(data.user)) + } + + return data +} + +/** + * Complete login with 2FA code + * @param request - Temp token and TOTP code * @returns Authentication response with token and user data */ -export async function login(credentials: LoginRequest): Promise { - const { data } = await apiClient.post('/auth/login', credentials) +export async function login2FA(request: TotpLogin2FARequest): Promise { + const { data } = await apiClient.post('/auth/login/2fa', request) // Store token and user data setAuthToken(data.access_token) @@ -133,8 +164,61 @@ export async function validatePromoCode(code: string): Promise { + const { data } = await apiClient.post('/auth/forgot-password', request) + return data +} + +/** + * Reset password request + */ +export interface ResetPasswordRequest { + email: string + token: string + new_password: string +} + +/** + * Reset password response + */ +export interface ResetPasswordResponse { + message: string +} + +/** + * Reset password with token + * @param request - Email, token, and new password + * @returns Response with message + */ +export async function resetPassword(request: ResetPasswordRequest): Promise { + const { data } = await apiClient.post('/auth/reset-password', request) + return data +} + export const authAPI = { login, + login2FA, + isTotp2FARequired, register, getCurrentUser, logout, @@ -144,7 +228,9 @@ export const authAPI = { clearAuthToken, getPublicSettings, sendVerifyCode, - validatePromoCode + validatePromoCode, + forgotPassword, + resetPassword } export default authAPI diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index e80cc753..c144a7e1 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -7,7 +7,7 @@ export { apiClient } from './client' // Auth API -export { authAPI } from './auth' +export { authAPI, isTotp2FARequired, type LoginResponse } from './auth' // User APIs export { keysAPI } from './keys' @@ -15,6 +15,7 @@ export { usageAPI } from './usage' export { userAPI } from './user' export { redeemAPI, type RedeemHistoryItem } from './redeem' export { userGroupsAPI } from './groups' +export { totpAPI } from './totp' // Admin APIs export { adminAPI } from './admin' diff --git a/frontend/src/api/totp.ts b/frontend/src/api/totp.ts new file mode 100644 index 00000000..cd658acb --- /dev/null +++ b/frontend/src/api/totp.ts @@ -0,0 +1,83 @@ +/** + * TOTP (2FA) API endpoints + * Handles Two-Factor Authentication with Google Authenticator + */ + +import { apiClient } from './client' +import type { + TotpStatus, + TotpSetupRequest, + TotpSetupResponse, + TotpEnableRequest, + TotpEnableResponse, + TotpDisableRequest, + TotpVerificationMethod +} from '@/types' + +/** + * Get TOTP status for current user + * @returns TOTP status including enabled state and feature availability + */ +export async function getStatus(): Promise { + const { data } = await apiClient.get('/user/totp/status') + return data +} + +/** + * Get verification method for TOTP operations + * @returns Method ('email' or 'password') required for setup/disable + */ +export async function getVerificationMethod(): Promise { + const { data } = await apiClient.get('/user/totp/verification-method') + return data +} + +/** + * Send email verification code for TOTP operations + * @returns Success response + */ +export async function sendVerifyCode(): Promise<{ success: boolean }> { + const { data } = await apiClient.post<{ success: boolean }>('/user/totp/send-code') + return data +} + +/** + * Initiate TOTP setup - generates secret and QR code + * @param request - Email code or password depending on verification method + * @returns Setup response with secret, QR code URL, and setup token + */ +export async function initiateSetup(request?: TotpSetupRequest): Promise { + const { data } = await apiClient.post('/user/totp/setup', request || {}) + return data +} + +/** + * Complete TOTP setup by verifying the code + * @param request - TOTP code and setup token + * @returns Enable response with success status and enabled timestamp + */ +export async function enable(request: TotpEnableRequest): Promise { + const { data } = await apiClient.post('/user/totp/enable', request) + return data +} + +/** + * Disable TOTP for current user + * @param request - Email code or password depending on verification method + * @returns Success response + */ +export async function disable(request: TotpDisableRequest): Promise<{ success: boolean }> { + const { data } = await apiClient.post<{ success: boolean }>('/user/totp/disable', request) + return data +} + +export const totpAPI = { + getStatus, + getVerificationMethod, + sendVerifyCode, + initiateSetup, + enable, + disable +} + +export default totpAPI diff --git a/frontend/src/components/auth/TotpLoginModal.vue b/frontend/src/components/auth/TotpLoginModal.vue new file mode 100644 index 00000000..03fa718d --- /dev/null +++ b/frontend/src/components/auth/TotpLoginModal.vue @@ -0,0 +1,176 @@ + + + diff --git a/frontend/src/components/common/DataTable.vue b/frontend/src/components/common/DataTable.vue index b74f52ee..c1e4333d 100644 --- a/frontend/src/components/common/DataTable.vue +++ b/frontend/src/components/common/DataTable.vue @@ -181,6 +181,10 @@ import Icon from '@/components/icons/Icon.vue' const { t } = useI18n() +const emit = defineEmits<{ + sort: [key: string, order: 'asc' | 'desc'] +}>() + // 表格容器引用 const tableWrapperRef = ref(null) const isScrollable = ref(false) @@ -289,6 +293,11 @@ interface Props { * If provided, DataTable will load the stored sort state on mount. */ sortStorageKey?: string + /** + * Enable server-side sorting mode. When true, clicking sort headers + * will emit 'sort' events instead of performing client-side sorting. + */ + serverSideSort?: boolean } const props = withDefaults(defineProps(), { @@ -296,7 +305,8 @@ const props = withDefaults(defineProps(), { stickyFirstColumn: true, stickyActionsColumn: true, expandableActions: true, - defaultSortOrder: 'asc' + defaultSortOrder: 'asc', + serverSideSort: false }) const sortKey = ref('') @@ -448,16 +458,26 @@ watch(actionsExpanded, async () => { }) const handleSort = (key: string) => { + let newOrder: 'asc' | 'desc' = 'asc' if (sortKey.value === key) { - sortOrder.value = sortOrder.value === 'asc' ? 'desc' : 'asc' - } else { + newOrder = sortOrder.value === 'asc' ? 'desc' : 'asc' + } + + if (props.serverSideSort) { + // Server-side sort mode: emit event and update internal state for UI feedback sortKey.value = key - sortOrder.value = 'asc' + sortOrder.value = newOrder + emit('sort', key, newOrder) + } else { + // Client-side sort mode: just update internal state + sortKey.value = key + sortOrder.value = newOrder } } const sortedData = computed(() => { - if (!sortKey.value || !props.data) return props.data + // Server-side sort mode: return data as-is (server handles sorting) + if (props.serverSideSort || !sortKey.value || !props.data) return props.data const key = sortKey.value const order = sortOrder.value diff --git a/frontend/src/components/user/profile/ProfileTotpCard.vue b/frontend/src/components/user/profile/ProfileTotpCard.vue new file mode 100644 index 00000000..77413e52 --- /dev/null +++ b/frontend/src/components/user/profile/ProfileTotpCard.vue @@ -0,0 +1,154 @@ + + + diff --git a/frontend/src/components/user/profile/TotpDisableDialog.vue b/frontend/src/components/user/profile/TotpDisableDialog.vue new file mode 100644 index 00000000..daca4067 --- /dev/null +++ b/frontend/src/components/user/profile/TotpDisableDialog.vue @@ -0,0 +1,179 @@ + + + diff --git a/frontend/src/components/user/profile/TotpSetupModal.vue b/frontend/src/components/user/profile/TotpSetupModal.vue new file mode 100644 index 00000000..3d9b79ec --- /dev/null +++ b/frontend/src/components/user/profile/TotpSetupModal.vue @@ -0,0 +1,400 @@ + + + diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 06f80977..04f385d3 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -146,7 +146,10 @@ export default { balance: 'Balance', available: 'Available', copiedToClipboard: 'Copied to clipboard', + copied: 'Copied', copyFailed: 'Failed to copy', + verifying: 'Verifying...', + processing: 'Processing...', contactSupport: 'Contact Support', add: 'Add', invalidEmail: 'Please enter a valid email address', @@ -271,7 +274,36 @@ export default { code: 'Code', state: 'State', fullUrl: 'Full URL' - } + }, + // Forgot password + forgotPassword: 'Forgot password?', + forgotPasswordTitle: 'Reset Your Password', + forgotPasswordHint: 'Enter your email address and we will send you a link to reset your password.', + sendResetLink: 'Send Reset Link', + sendingResetLink: 'Sending...', + sendResetLinkFailed: 'Failed to send reset link. Please try again.', + resetEmailSent: 'Reset Link Sent', + resetEmailSentHint: 'If an account exists with this email, you will receive a password reset link shortly. Please check your inbox and spam folder.', + backToLogin: 'Back to Login', + rememberedPassword: 'Remembered your password?', + // Reset password + resetPasswordTitle: 'Set New Password', + resetPasswordHint: 'Enter your new password below.', + newPassword: 'New Password', + newPasswordPlaceholder: 'Enter your new password', + confirmPassword: 'Confirm Password', + confirmPasswordPlaceholder: 'Confirm your new password', + confirmPasswordRequired: 'Please confirm your password', + passwordsDoNotMatch: 'Passwords do not match', + resetPassword: 'Reset Password', + resettingPassword: 'Resetting...', + resetPasswordFailed: 'Failed to reset password. Please try again.', + passwordResetSuccess: 'Password Reset Successful', + passwordResetSuccessHint: 'Your password has been reset. You can now sign in with your new password.', + invalidResetLink: 'Invalid Reset Link', + invalidResetLinkHint: 'This password reset link is invalid or has expired. Please request a new one.', + requestNewResetLink: 'Request New Reset Link', + invalidOrExpiredToken: 'The password reset link is invalid or has expired. Please request a new one.' }, // Dashboard @@ -554,7 +586,46 @@ export default { passwordsNotMatch: 'New passwords do not match', passwordTooShort: 'Password must be at least 8 characters long', passwordChangeSuccess: 'Password changed successfully', - passwordChangeFailed: 'Failed to change password' + passwordChangeFailed: 'Failed to change password', + // TOTP 2FA + totp: { + title: 'Two-Factor Authentication (2FA)', + description: 'Enhance account security with Google Authenticator or similar apps', + enabled: 'Enabled', + enabledAt: 'Enabled at', + notEnabled: 'Not Enabled', + notEnabledHint: 'Enable two-factor authentication to enhance account security', + enable: 'Enable', + disable: 'Disable', + featureDisabled: 'Feature Unavailable', + featureDisabledHint: 'Two-factor authentication has not been enabled by the administrator', + setupTitle: 'Set Up Two-Factor Authentication', + setupStep1: 'Scan the QR code below with your authenticator app', + setupStep2: 'Enter the 6-digit code from your app', + manualEntry: "Can't scan? Enter the key manually:", + enterCode: 'Enter 6-digit code', + verify: 'Verify', + setupFailed: 'Failed to get setup information', + verifyFailed: 'Invalid code, please try again', + enableSuccess: 'Two-factor authentication enabled', + disableTitle: 'Disable Two-Factor Authentication', + disableWarning: 'After disabling, you will no longer need a verification code to log in. This may reduce your account security.', + enterPassword: 'Enter your current password to confirm', + confirmDisable: 'Confirm Disable', + disableSuccess: 'Two-factor authentication disabled', + disableFailed: 'Failed to disable, please check your password', + loginTitle: 'Two-Factor Authentication', + loginHint: 'Enter the 6-digit code from your authenticator app', + loginFailed: 'Verification failed, please try again', + // New translations for email verification + verifyEmailFirst: 'Please verify your email first', + verifyPasswordFirst: 'Please verify your identity first', + emailCode: 'Email Verification Code', + enterEmailCode: 'Enter 6-digit code', + sendCode: 'Send Code', + codeSent: 'Verification code sent to your email', + sendCodeFailed: 'Failed to send verification code' + } }, // Empty States @@ -2743,7 +2814,13 @@ export default { emailVerification: 'Email Verification', emailVerificationHint: 'Require email verification for new registrations', promoCode: 'Promo Code', - promoCodeHint: 'Allow users to use promo codes during registration' + promoCodeHint: 'Allow users to use promo codes during registration', + passwordReset: 'Password Reset', + passwordResetHint: 'Allow users to reset their password via email', + totp: 'Two-Factor Authentication (2FA)', + totpHint: 'Allow users to use authenticator apps like Google Authenticator', + totpKeyNotConfigured: + 'Please configure TOTP_ENCRYPTION_KEY in environment variables first. Generate a key with: openssl rand -hex 32' }, turnstile: { title: 'Cloudflare Turnstile', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 01457859..3ba7f64b 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -143,7 +143,10 @@ export default { balance: '余额', available: '可用', copiedToClipboard: '已复制到剪贴板', + copied: '已复制', copyFailed: '复制失败', + verifying: '验证中...', + processing: '处理中...', contactSupport: '联系客服', add: '添加', invalidEmail: '请输入有效的邮箱地址', @@ -268,7 +271,36 @@ export default { code: '授权码', state: '状态', fullUrl: '完整URL' - } + }, + // 忘记密码 + forgotPassword: '忘记密码?', + forgotPasswordTitle: '重置密码', + forgotPasswordHint: '输入您的邮箱地址,我们将向您发送密码重置链接。', + sendResetLink: '发送重置链接', + sendingResetLink: '发送中...', + sendResetLinkFailed: '发送重置链接失败,请重试。', + resetEmailSent: '重置链接已发送', + resetEmailSentHint: '如果该邮箱已注册,您将很快收到密码重置链接。请检查您的收件箱和垃圾邮件文件夹。', + backToLogin: '返回登录', + rememberedPassword: '想起密码了?', + // 重置密码 + resetPasswordTitle: '设置新密码', + resetPasswordHint: '请在下方输入您的新密码。', + newPassword: '新密码', + newPasswordPlaceholder: '输入新密码', + confirmPassword: '确认密码', + confirmPasswordPlaceholder: '再次输入新密码', + confirmPasswordRequired: '请确认您的密码', + passwordsDoNotMatch: '两次输入的密码不一致', + resetPassword: '重置密码', + resettingPassword: '重置中...', + resetPasswordFailed: '重置密码失败,请重试。', + passwordResetSuccess: '密码重置成功', + passwordResetSuccessHint: '您的密码已重置。现在可以使用新密码登录。', + invalidResetLink: '无效的重置链接', + invalidResetLinkHint: '此密码重置链接无效或已过期。请重新请求一个新链接。', + requestNewResetLink: '请求新的重置链接', + invalidOrExpiredToken: '密码重置链接无效或已过期。请重新请求一个新链接。' }, // Dashboard @@ -550,7 +582,46 @@ export default { passwordsNotMatch: '两次输入的密码不一致', passwordTooShort: '密码至少需要 8 个字符', passwordChangeSuccess: '密码修改成功', - passwordChangeFailed: '密码修改失败' + passwordChangeFailed: '密码修改失败', + // TOTP 2FA + totp: { + title: '双因素认证 (2FA)', + description: '使用 Google Authenticator 等应用增强账户安全', + enabled: '已启用', + enabledAt: '启用时间', + notEnabled: '未启用', + notEnabledHint: '启用双因素认证可以增强账户安全性', + enable: '启用', + disable: '禁用', + featureDisabled: '功能未开放', + featureDisabledHint: '管理员尚未开放双因素认证功能', + setupTitle: '设置双因素认证', + setupStep1: '使用认证器应用扫描下方二维码', + setupStep2: '输入应用显示的 6 位验证码', + manualEntry: '无法扫码?手动输入密钥:', + enterCode: '输入 6 位验证码', + verify: '验证', + setupFailed: '获取设置信息失败', + verifyFailed: '验证码错误,请重试', + enableSuccess: '双因素认证已启用', + disableTitle: '禁用双因素认证', + disableWarning: '禁用后,登录时将不再需要验证码。这可能会降低您的账户安全性。', + enterPassword: '请输入当前密码确认', + confirmDisable: '确认禁用', + disableSuccess: '双因素认证已禁用', + disableFailed: '禁用失败,请检查密码是否正确', + loginTitle: '双因素认证', + loginHint: '请输入您认证器应用显示的 6 位验证码', + loginFailed: '验证失败,请重试', + // New translations for email verification + verifyEmailFirst: '请先验证您的邮箱', + verifyPasswordFirst: '请先验证您的身份', + emailCode: '邮箱验证码', + enterEmailCode: '请输入 6 位验证码', + sendCode: '发送验证码', + codeSent: '验证码已发送到您的邮箱', + sendCodeFailed: '发送验证码失败' + } }, // Empty States @@ -2896,7 +2967,13 @@ export default { emailVerification: '邮箱验证', emailVerificationHint: '新用户注册时需要验证邮箱', promoCode: '优惠码', - promoCodeHint: '允许用户在注册时使用优惠码' + promoCodeHint: '允许用户在注册时使用优惠码', + passwordReset: '忘记密码', + passwordResetHint: '允许用户通过邮箱重置密码', + totp: '双因素认证 (2FA)', + totpHint: '允许用户使用 Google Authenticator 等应用进行二次验证', + totpKeyNotConfigured: + '请先在环境变量中配置 TOTP_ENCRYPTION_KEY。使用命令 openssl rand -hex 32 生成密钥。' }, turnstile: { title: 'Cloudflare Turnstile', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 9dce8624..06217228 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -79,6 +79,24 @@ const routes: RouteRecordRaw[] = [ title: 'LinuxDo OAuth Callback' } }, + { + path: '/forgot-password', + name: 'ForgotPassword', + component: () => import('@/views/auth/ForgotPasswordView.vue'), + meta: { + requiresAuth: false, + title: 'Forgot Password' + } + }, + { + path: '/reset-password', + name: 'ResetPassword', + component: () => import('@/views/auth/ResetPasswordView.vue'), + meta: { + requiresAuth: false, + title: 'Reset Password' + } + }, // ==================== User Routes ==================== { diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index c3bb3275..5a8533d8 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -313,6 +313,7 @@ export const useAppStore = defineStore('app', () => { registration_enabled: false, email_verify_enabled: false, promo_code_enabled: true, + password_reset_enabled: false, turnstile_enabled: false, turnstile_site_key: '', site_name: siteName.value, diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index 4076e154..e4612f5e 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -5,8 +5,8 @@ import { defineStore } from 'pinia' import { ref, computed, readonly } from 'vue' -import { authAPI } from '@/api' -import type { User, LoginRequest, RegisterRequest } from '@/types' +import { authAPI, isTotp2FARequired, type LoginResponse } from '@/api' +import type { User, LoginRequest, RegisterRequest, AuthResponse } from '@/types' const AUTH_TOKEN_KEY = 'auth_token' const AUTH_USER_KEY = 'auth_user' @@ -91,32 +91,23 @@ export const useAuthStore = defineStore('auth', () => { /** * User login - * @param credentials - Login credentials (username and password) - * @returns Promise resolving to the authenticated user + * @param credentials - Login credentials (email and password) + * @returns Promise resolving to the login response (may require 2FA) * @throws Error if login fails */ - async function login(credentials: LoginRequest): Promise { + async function login(credentials: LoginRequest): Promise { try { const response = await authAPI.login(credentials) - // Store token and user - token.value = response.access_token - - // Extract run_mode if present - if (response.user.run_mode) { - runMode.value = response.user.run_mode + // If 2FA is required, return the response without setting auth state + if (isTotp2FARequired(response)) { + return response } - const { run_mode: _run_mode, ...userData } = response.user - user.value = userData - // Persist to localStorage - localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) - localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData)) + // Set auth state from the response + setAuthFromResponse(response) - // Start auto-refresh interval - startAutoRefresh() - - return userData + return response } catch (error) { // Clear any partial state on error clearAuth() @@ -124,6 +115,47 @@ export const useAuthStore = defineStore('auth', () => { } } + /** + * Complete login with 2FA code + * @param tempToken - Temporary token from initial login + * @param totpCode - 6-digit TOTP code + * @returns Promise resolving to the authenticated user + * @throws Error if 2FA verification fails + */ + async function login2FA(tempToken: string, totpCode: string): Promise { + try { + const response = await authAPI.login2FA({ temp_token: tempToken, totp_code: totpCode }) + setAuthFromResponse(response) + return user.value! + } catch (error) { + clearAuth() + throw error + } + } + + /** + * Set auth state from an AuthResponse + * Internal helper function + */ + function setAuthFromResponse(response: AuthResponse): void { + // Store token and user + token.value = response.access_token + + // Extract run_mode if present + if (response.user.run_mode) { + runMode.value = response.user.run_mode + } + const { run_mode: _run_mode, ...userData } = response.user + user.value = userData + + // Persist to localStorage + localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) + localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData)) + + // Start auto-refresh interval + startAutoRefresh() + } + /** * User registration * @param userData - Registration data (username, email, password) @@ -253,6 +285,7 @@ export const useAuthStore = defineStore('auth', () => { // Actions login, + login2FA, register, setToken, logout, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1e73ea9b..cefb914d 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -71,6 +71,7 @@ export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean promo_code_enabled: boolean + password_reset_enabled: boolean turnstile_enabled: boolean turnstile_site_key: string site_name: string @@ -1107,3 +1108,52 @@ export interface UpdatePromoCodeRequest { expires_at?: number | null notes?: string } + +// ==================== TOTP (2FA) Types ==================== + +export interface TotpStatus { + enabled: boolean + enabled_at: number | null // Unix timestamp in seconds + feature_enabled: boolean +} + +export interface TotpSetupRequest { + email_code?: string + password?: string +} + +export interface TotpSetupResponse { + secret: string + qr_code_url: string + setup_token: string + countdown: number +} + +export interface TotpEnableRequest { + totp_code: string + setup_token: string +} + +export interface TotpEnableResponse { + success: boolean +} + +export interface TotpDisableRequest { + email_code?: string + password?: string +} + +export interface TotpVerificationMethod { + method: 'email' | 'password' +} + +export interface TotpLoginResponse { + requires_2fa: boolean + temp_token?: string + user_email_masked?: string +} + +export interface TotpLogin2FARequest { + temp_token: string + totp_code: string +} diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 657013a9..b61dcc09 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -338,6 +338,47 @@ + + +
+
+ +

+ {{ t('admin.settings.registration.passwordResetHint') }} +

+
+ +
+ + +
+
+ +

+ {{ t('admin.settings.registration.totpHint') }} +

+ +

+ {{ t('admin.settings.registration.totpKeyNotConfigured') }} +

+
+ +
@@ -1029,6 +1070,9 @@ const form = reactive({ registration_enabled: true, email_verify_enabled: false, promo_code_enabled: true, + password_reset_enabled: false, + totp_enabled: false, + totp_encryption_key_configured: false, default_balance: 0, default_concurrency: 1, site_name: 'TianShuAPI', @@ -1152,6 +1196,8 @@ async function saveSettings() { registration_enabled: form.registration_enabled, email_verify_enabled: form.email_verify_enabled, promo_code_enabled: form.promo_code_enabled, + password_reset_enabled: form.password_reset_enabled, + totp_enabled: form.totp_enabled, default_balance: form.default_balance, default_concurrency: form.default_concurrency, site_name: form.site_name, diff --git a/frontend/src/views/admin/SubscriptionsView.vue b/frontend/src/views/admin/SubscriptionsView.vue index 9b0e5ecb..eb2b40d5 100644 --- a/frontend/src/views/admin/SubscriptionsView.vue +++ b/frontend/src/views/admin/SubscriptionsView.vue @@ -154,7 +154,13 @@ diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue index 2d811fa5..0967e2b9 100644 --- a/frontend/src/views/user/ProfileView.vue +++ b/frontend/src/views/user/ProfileView.vue @@ -15,6 +15,7 @@ + @@ -27,6 +28,7 @@ import StatCard from '@/components/common/StatCard.vue' import ProfileInfoCard from '@/components/user/profile/ProfileInfoCard.vue' import ProfileEditForm from '@/components/user/profile/ProfileEditForm.vue' import ProfilePasswordForm from '@/components/user/profile/ProfilePasswordForm.vue' +import ProfileTotpCard from '@/components/user/profile/ProfileTotpCard.vue' import { Icon } from '@/components/icons' const { t } = useI18n(); const authStore = useAuthStore(); const user = computed(() => authStore.user)