密码重置请求
+您已请求重置密码。请点击下方按钮设置新密码:
+ 重置密码 +此链接将在 30 分钟后失效。
+如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。
+如果按钮无法点击,请复制以下链接到浏览器中打开:
+%s
+diff --git a/README.md b/README.md
index fa965e6f..e8e9c5a5 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ English | [中文](README_CN.md)
## Demo
-Try Sub2API online: **https://v2.pincc.ai/**
+Try Sub2API online: **https://demo.sub2api.org/**
Demo credentials (shared demo environment; **not** created automatically for self-hosted installs):
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 5ef04a66..d9ff788e 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -70,6 +70,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
+ subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -138,6 +139,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
+ {"SubscriptionExpiryService", func() error {
+ subscriptionExpiry.Stop()
+ return nil
+ }},
{"PricingService", func() error {
pricing.Stop()
return nil
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 7b22a31e..71624091 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -63,7 +63,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator)
- authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService)
+ secretEncryptor, err := repository.NewAESEncryptor(configConfig)
+ if err != nil {
+ return nil, err
+ }
+ totpCache := repository.NewTotpCache(redisClient)
+ totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
+ authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, totpService)
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@@ -165,7 +171,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
- handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
+ totpHandler := handler.NewTotpHandler(totpService)
+ handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
@@ -178,7 +185,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+ subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -211,6 +219,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
+ subscriptionExpiry *service.SubscriptionExpiryService,
usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
@@ -278,6 +287,10 @@ func provideCleanup(
accountExpiry.Stop()
return nil
}},
+ {"SubscriptionExpiryService", func() error {
+ subscriptionExpiry.Stop()
+ return nil
+ }},
{"PricingService", func() error {
pricing.Stop()
return nil
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index d1f05186..d2a39331 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -610,6 +610,9 @@ var (
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "username", Type: field.TypeString, Size: 100, Default: ""},
{Name: "notes", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
+ {Name: "totp_enabled", Type: field.TypeBool, Default: false},
+ {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
}
// UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 9b330616..7f3071c2 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -14360,6 +14360,9 @@ type UserMutation struct {
status *string
username *string
notes *string
+ totp_secret_encrypted *string
+ totp_enabled *bool
+ totp_enabled_at *time.Time
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -14937,6 +14940,140 @@ func (m *UserMutation) ResetNotes() {
m.notes = nil
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (m *UserMutation) SetTotpSecretEncrypted(s string) {
+ m.totp_secret_encrypted = &s
+}
+
+// TotpSecretEncrypted returns the value of the "totp_secret_encrypted" field in the mutation.
+func (m *UserMutation) TotpSecretEncrypted() (r string, exists bool) {
+ v := m.totp_secret_encrypted
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpSecretEncrypted returns the old "totp_secret_encrypted" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotpSecretEncrypted(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpSecretEncrypted is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpSecretEncrypted requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpSecretEncrypted: %w", err)
+ }
+ return oldValue.TotpSecretEncrypted, nil
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (m *UserMutation) ClearTotpSecretEncrypted() {
+ m.totp_secret_encrypted = nil
+ m.clearedFields[user.FieldTotpSecretEncrypted] = struct{}{}
+}
+
+// TotpSecretEncryptedCleared returns if the "totp_secret_encrypted" field was cleared in this mutation.
+func (m *UserMutation) TotpSecretEncryptedCleared() bool {
+ _, ok := m.clearedFields[user.FieldTotpSecretEncrypted]
+ return ok
+}
+
+// ResetTotpSecretEncrypted resets all changes to the "totp_secret_encrypted" field.
+func (m *UserMutation) ResetTotpSecretEncrypted() {
+ m.totp_secret_encrypted = nil
+ delete(m.clearedFields, user.FieldTotpSecretEncrypted)
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (m *UserMutation) SetTotpEnabled(b bool) {
+ m.totp_enabled = &b
+}
+
+// TotpEnabled returns the value of the "totp_enabled" field in the mutation.
+func (m *UserMutation) TotpEnabled() (r bool, exists bool) {
+ v := m.totp_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpEnabled returns the old "totp_enabled" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotpEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpEnabled: %w", err)
+ }
+ return oldValue.TotpEnabled, nil
+}
+
+// ResetTotpEnabled resets all changes to the "totp_enabled" field.
+func (m *UserMutation) ResetTotpEnabled() {
+ m.totp_enabled = nil
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (m *UserMutation) SetTotpEnabledAt(t time.Time) {
+ m.totp_enabled_at = &t
+}
+
+// TotpEnabledAt returns the value of the "totp_enabled_at" field in the mutation.
+func (m *UserMutation) TotpEnabledAt() (r time.Time, exists bool) {
+ v := m.totp_enabled_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldTotpEnabledAt returns the old "totp_enabled_at" field's value of the User entity.
+// If the User object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UserMutation) OldTotpEnabledAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldTotpEnabledAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldTotpEnabledAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldTotpEnabledAt: %w", err)
+ }
+ return oldValue.TotpEnabledAt, nil
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (m *UserMutation) ClearTotpEnabledAt() {
+ m.totp_enabled_at = nil
+ m.clearedFields[user.FieldTotpEnabledAt] = struct{}{}
+}
+
+// TotpEnabledAtCleared returns if the "totp_enabled_at" field was cleared in this mutation.
+func (m *UserMutation) TotpEnabledAtCleared() bool {
+ _, ok := m.clearedFields[user.FieldTotpEnabledAt]
+ return ok
+}
+
+// ResetTotpEnabledAt resets all changes to the "totp_enabled_at" field.
+func (m *UserMutation) ResetTotpEnabledAt() {
+ m.totp_enabled_at = nil
+ delete(m.clearedFields, user.FieldTotpEnabledAt)
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -15403,7 +15540,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UserMutation) Fields() []string {
- fields := make([]string, 0, 11)
+ fields := make([]string, 0, 14)
if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt)
}
@@ -15437,6 +15574,15 @@ func (m *UserMutation) Fields() []string {
if m.notes != nil {
fields = append(fields, user.FieldNotes)
}
+ if m.totp_secret_encrypted != nil {
+ fields = append(fields, user.FieldTotpSecretEncrypted)
+ }
+ if m.totp_enabled != nil {
+ fields = append(fields, user.FieldTotpEnabled)
+ }
+ if m.totp_enabled_at != nil {
+ fields = append(fields, user.FieldTotpEnabledAt)
+ }
return fields
}
@@ -15467,6 +15613,12 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.Username()
case user.FieldNotes:
return m.Notes()
+ case user.FieldTotpSecretEncrypted:
+ return m.TotpSecretEncrypted()
+ case user.FieldTotpEnabled:
+ return m.TotpEnabled()
+ case user.FieldTotpEnabledAt:
+ return m.TotpEnabledAt()
}
return nil, false
}
@@ -15498,6 +15650,12 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldUsername(ctx)
case user.FieldNotes:
return m.OldNotes(ctx)
+ case user.FieldTotpSecretEncrypted:
+ return m.OldTotpSecretEncrypted(ctx)
+ case user.FieldTotpEnabled:
+ return m.OldTotpEnabled(ctx)
+ case user.FieldTotpEnabledAt:
+ return m.OldTotpEnabledAt(ctx)
}
return nil, fmt.Errorf("unknown User field %s", name)
}
@@ -15584,6 +15742,27 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
}
m.SetNotes(v)
return nil
+ case user.FieldTotpSecretEncrypted:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpSecretEncrypted(v)
+ return nil
+ case user.FieldTotpEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpEnabled(v)
+ return nil
+ case user.FieldTotpEnabledAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetTotpEnabledAt(v)
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
@@ -15644,6 +15823,12 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldDeletedAt) {
fields = append(fields, user.FieldDeletedAt)
}
+ if m.FieldCleared(user.FieldTotpSecretEncrypted) {
+ fields = append(fields, user.FieldTotpSecretEncrypted)
+ }
+ if m.FieldCleared(user.FieldTotpEnabledAt) {
+ fields = append(fields, user.FieldTotpEnabledAt)
+ }
return fields
}
@@ -15661,6 +15846,12 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldDeletedAt:
m.ClearDeletedAt()
return nil
+ case user.FieldTotpSecretEncrypted:
+ m.ClearTotpSecretEncrypted()
+ return nil
+ case user.FieldTotpEnabledAt:
+ m.ClearTotpEnabledAt()
+ return nil
}
return fmt.Errorf("unknown User nullable field %s", name)
}
@@ -15702,6 +15893,15 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldNotes:
m.ResetNotes()
return nil
+ case user.FieldTotpSecretEncrypted:
+ m.ResetTotpSecretEncrypted()
+ return nil
+ case user.FieldTotpEnabled:
+ m.ResetTotpEnabled()
+ return nil
+ case user.FieldTotpEnabledAt:
+ m.ResetTotpEnabledAt()
+ return nil
}
return fmt.Errorf("unknown User field %s", name)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 1e3f4cbe..14323f8c 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -736,6 +736,10 @@ func init() {
userDescNotes := userFields[7].Descriptor()
// user.DefaultNotes holds the default value on creation for the notes field.
user.DefaultNotes = userDescNotes.Default.(string)
+ // userDescTotpEnabled is the schema descriptor for totp_enabled field.
+ userDescTotpEnabled := userFields[9].Descriptor()
+ // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
+ user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index 79dc2286..335c1cc8 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -61,6 +61,17 @@ func (User) Fields() []ent.Field {
field.String("notes").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default(""),
+
+ // TOTP 双因素认证字段
+ field.String("totp_secret_encrypted").
+ SchemaType(map[string]string{dialect.Postgres: "text"}).
+ Optional().
+ Nillable(),
+ field.Bool("totp_enabled").
+ Default(false),
+ field.Time("totp_enabled_at").
+ Optional().
+ Nillable(),
}
}
diff --git a/backend/ent/user.go b/backend/ent/user.go
index 0b9a48cc..82830a95 100644
--- a/backend/ent/user.go
+++ b/backend/ent/user.go
@@ -39,6 +39,12 @@ type User struct {
Username string `json:"username,omitempty"`
// Notes holds the value of the "notes" field.
Notes string `json:"notes,omitempty"`
+ // TotpSecretEncrypted holds the value of the "totp_secret_encrypted" field.
+ TotpSecretEncrypted *string `json:"totp_secret_encrypted,omitempty"`
+ // TotpEnabled holds the value of the "totp_enabled" field.
+ TotpEnabled bool `json:"totp_enabled,omitempty"`
+ // TotpEnabledAt holds the value of the "totp_enabled_at" field.
+ TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"`
@@ -156,13 +162,15 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case user.FieldTotpEnabled:
+ values[i] = new(sql.NullBool)
case user.FieldBalance:
values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64)
- case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes:
+ case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
values[i] = new(sql.NullString)
- case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt:
+ case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -252,6 +260,26 @@ func (_m *User) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Notes = value.String
}
+ case user.FieldTotpSecretEncrypted:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_secret_encrypted", values[i])
+ } else if value.Valid {
+ _m.TotpSecretEncrypted = new(string)
+ *_m.TotpSecretEncrypted = value.String
+ }
+ case user.FieldTotpEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_enabled", values[i])
+ } else if value.Valid {
+ _m.TotpEnabled = value.Bool
+ }
+ case user.FieldTotpEnabledAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field totp_enabled_at", values[i])
+ } else if value.Valid {
+ _m.TotpEnabledAt = new(time.Time)
+ *_m.TotpEnabledAt = value.Time
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -367,6 +395,19 @@ func (_m *User) String() string {
builder.WriteString(", ")
builder.WriteString("notes=")
builder.WriteString(_m.Notes)
+ builder.WriteString(", ")
+ if v := _m.TotpSecretEncrypted; v != nil {
+ builder.WriteString("totp_secret_encrypted=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("totp_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.TotpEnabled))
+ builder.WriteString(", ")
+ if v := _m.TotpEnabledAt; v != nil {
+ builder.WriteString("totp_enabled_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go
index 1be1d871..0685ed72 100644
--- a/backend/ent/user/user.go
+++ b/backend/ent/user/user.go
@@ -37,6 +37,12 @@ const (
FieldUsername = "username"
// FieldNotes holds the string denoting the notes field in the database.
FieldNotes = "notes"
+ // FieldTotpSecretEncrypted holds the string denoting the totp_secret_encrypted field in the database.
+ FieldTotpSecretEncrypted = "totp_secret_encrypted"
+ // FieldTotpEnabled holds the string denoting the totp_enabled field in the database.
+ FieldTotpEnabled = "totp_enabled"
+ // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
+ FieldTotpEnabledAt = "totp_enabled_at"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -134,6 +140,9 @@ var Columns = []string{
FieldStatus,
FieldUsername,
FieldNotes,
+ FieldTotpSecretEncrypted,
+ FieldTotpEnabled,
+ FieldTotpEnabledAt,
}
var (
@@ -188,6 +197,8 @@ var (
UsernameValidator func(string) error
// DefaultNotes holds the default value on creation for the "notes" field.
DefaultNotes string
+ // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
+ DefaultTotpEnabled bool
)
// OrderOption defines the ordering options for the User queries.
@@ -253,6 +264,21 @@ func ByNotes(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldNotes, opts...).ToFunc()
}
+// ByTotpSecretEncrypted orders the results by the totp_secret_encrypted field.
+func ByTotpSecretEncrypted(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpSecretEncrypted, opts...).ToFunc()
+}
+
+// ByTotpEnabled orders the results by the totp_enabled field.
+func ByTotpEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpEnabled, opts...).ToFunc()
+}
+
+// ByTotpEnabledAt orders the results by the totp_enabled_at field.
+func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go
index 6a460f10..3dc4fec7 100644
--- a/backend/ent/user/where.go
+++ b/backend/ent/user/where.go
@@ -110,6 +110,21 @@ func Notes(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldNotes, v))
}
+// TotpSecretEncrypted applies equality check predicate on the "totp_secret_encrypted" field. It's identical to TotpSecretEncryptedEQ.
+func TotpSecretEncrypted(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
+}
+
+// TotpEnabled applies equality check predicate on the "totp_enabled" field. It's identical to TotpEnabledEQ.
+func TotpEnabled(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
+}
+
+// TotpEnabledAt applies equality check predicate on the "totp_enabled_at" field. It's identical to TotpEnabledAtEQ.
+func TotpEnabledAt(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@@ -710,6 +725,141 @@ func NotesContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldNotes, v))
}
+// TotpSecretEncryptedEQ applies the EQ predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedEQ(v string) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedNEQ applies the NEQ predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedNEQ(v string) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedIn applies the In predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotpSecretEncrypted, vs...))
+}
+
+// TotpSecretEncryptedNotIn applies the NotIn predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedNotIn(vs ...string) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotpSecretEncrypted, vs...))
+}
+
+// TotpSecretEncryptedGT applies the GT predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedGT(v string) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedGTE applies the GTE predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedGTE(v string) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedLT applies the LT predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedLT(v string) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedLTE applies the LTE predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedLTE(v string) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedContains applies the Contains predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedContains(v string) predicate.User {
+ return predicate.User(sql.FieldContains(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedHasPrefix applies the HasPrefix predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedHasPrefix(v string) predicate.User {
+ return predicate.User(sql.FieldHasPrefix(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedHasSuffix applies the HasSuffix predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedHasSuffix(v string) predicate.User {
+ return predicate.User(sql.FieldHasSuffix(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedIsNil applies the IsNil predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldTotpSecretEncrypted))
+}
+
+// TotpSecretEncryptedNotNil applies the NotNil predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldTotpSecretEncrypted))
+}
+
+// TotpSecretEncryptedEqualFold applies the EqualFold predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedEqualFold(v string) predicate.User {
+ return predicate.User(sql.FieldEqualFold(FieldTotpSecretEncrypted, v))
+}
+
+// TotpSecretEncryptedContainsFold applies the ContainsFold predicate on the "totp_secret_encrypted" field.
+func TotpSecretEncryptedContainsFold(v string) predicate.User {
+ return predicate.User(sql.FieldContainsFold(FieldTotpSecretEncrypted, v))
+}
+
+// TotpEnabledEQ applies the EQ predicate on the "totp_enabled" field.
+func TotpEnabledEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabled, v))
+}
+
+// TotpEnabledNEQ applies the NEQ predicate on the "totp_enabled" field.
+func TotpEnabledNEQ(v bool) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotpEnabled, v))
+}
+
+// TotpEnabledAtEQ applies the EQ predicate on the "totp_enabled_at" field.
+func TotpEnabledAtEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtNEQ applies the NEQ predicate on the "totp_enabled_at" field.
+func TotpEnabledAtNEQ(v time.Time) predicate.User {
+ return predicate.User(sql.FieldNEQ(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtIn applies the In predicate on the "totp_enabled_at" field.
+func TotpEnabledAtIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldIn(FieldTotpEnabledAt, vs...))
+}
+
+// TotpEnabledAtNotIn applies the NotIn predicate on the "totp_enabled_at" field.
+func TotpEnabledAtNotIn(vs ...time.Time) predicate.User {
+ return predicate.User(sql.FieldNotIn(FieldTotpEnabledAt, vs...))
+}
+
+// TotpEnabledAtGT applies the GT predicate on the "totp_enabled_at" field.
+func TotpEnabledAtGT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGT(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtGTE applies the GTE predicate on the "totp_enabled_at" field.
+func TotpEnabledAtGTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldGTE(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtLT applies the LT predicate on the "totp_enabled_at" field.
+func TotpEnabledAtLT(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLT(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtLTE applies the LTE predicate on the "totp_enabled_at" field.
+func TotpEnabledAtLTE(v time.Time) predicate.User {
+ return predicate.User(sql.FieldLTE(FieldTotpEnabledAt, v))
+}
+
+// TotpEnabledAtIsNil applies the IsNil predicate on the "totp_enabled_at" field.
+func TotpEnabledAtIsNil() predicate.User {
+ return predicate.User(sql.FieldIsNull(FieldTotpEnabledAt))
+}
+
+// TotpEnabledAtNotNil applies the NotNil predicate on the "totp_enabled_at" field.
+func TotpEnabledAtNotNil() predicate.User {
+ return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) {
diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go
index e12e476c..6b4ebc59 100644
--- a/backend/ent/user_create.go
+++ b/backend/ent/user_create.go
@@ -167,6 +167,48 @@ func (_c *UserCreate) SetNillableNotes(v *string) *UserCreate {
return _c
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (_c *UserCreate) SetTotpSecretEncrypted(v string) *UserCreate {
+ _c.mutation.SetTotpSecretEncrypted(v)
+ return _c
+}
+
+// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotpSecretEncrypted(v *string) *UserCreate {
+ if v != nil {
+ _c.SetTotpSecretEncrypted(*v)
+ }
+ return _c
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (_c *UserCreate) SetTotpEnabled(v bool) *UserCreate {
+ _c.mutation.SetTotpEnabled(v)
+ return _c
+}
+
+// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotpEnabled(v *bool) *UserCreate {
+ if v != nil {
+ _c.SetTotpEnabled(*v)
+ }
+ return _c
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (_c *UserCreate) SetTotpEnabledAt(v time.Time) *UserCreate {
+ _c.mutation.SetTotpEnabledAt(v)
+ return _c
+}
+
+// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
+func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
+ if v != nil {
+ _c.SetTotpEnabledAt(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -362,6 +404,10 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultNotes
_c.mutation.SetNotes(v)
}
+ if _, ok := _c.mutation.TotpEnabled(); !ok {
+ v := user.DefaultTotpEnabled
+ _c.mutation.SetTotpEnabled(v)
+ }
return nil
}
@@ -422,6 +468,9 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.Notes(); !ok {
return &ValidationError{Name: "notes", err: errors.New(`ent: missing required field "User.notes"`)}
}
+ if _, ok := _c.mutation.TotpEnabled(); !ok {
+ return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
+ }
return nil
}
@@ -493,6 +542,18 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldNotes, field.TypeString, value)
_node.Notes = value
}
+ if value, ok := _c.mutation.TotpSecretEncrypted(); ok {
+ _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
+ _node.TotpSecretEncrypted = &value
+ }
+ if value, ok := _c.mutation.TotpEnabled(); ok {
+ _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
+ _node.TotpEnabled = value
+ }
+ if value, ok := _c.mutation.TotpEnabledAt(); ok {
+ _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
+ _node.TotpEnabledAt = &value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -815,6 +876,54 @@ func (u *UserUpsert) UpdateNotes() *UserUpsert {
return u
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (u *UserUpsert) SetTotpSecretEncrypted(v string) *UserUpsert {
+ u.Set(user.FieldTotpSecretEncrypted, v)
+ return u
+}
+
+// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotpSecretEncrypted() *UserUpsert {
+ u.SetExcluded(user.FieldTotpSecretEncrypted)
+ return u
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (u *UserUpsert) ClearTotpSecretEncrypted() *UserUpsert {
+ u.SetNull(user.FieldTotpSecretEncrypted)
+ return u
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (u *UserUpsert) SetTotpEnabled(v bool) *UserUpsert {
+ u.Set(user.FieldTotpEnabled, v)
+ return u
+}
+
+// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotpEnabled() *UserUpsert {
+ u.SetExcluded(user.FieldTotpEnabled)
+ return u
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (u *UserUpsert) SetTotpEnabledAt(v time.Time) *UserUpsert {
+ u.Set(user.FieldTotpEnabledAt, v)
+ return u
+}
+
+// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
+func (u *UserUpsert) UpdateTotpEnabledAt() *UserUpsert {
+ u.SetExcluded(user.FieldTotpEnabledAt)
+ return u
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
+ u.SetNull(user.FieldTotpEnabledAt)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1021,6 +1130,62 @@ func (u *UserUpsertOne) UpdateNotes() *UserUpsertOne {
})
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (u *UserUpsertOne) SetTotpSecretEncrypted(v string) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpSecretEncrypted(v)
+ })
+}
+
+// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotpSecretEncrypted() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpSecretEncrypted()
+ })
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (u *UserUpsertOne) ClearTotpSecretEncrypted() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpSecretEncrypted()
+ })
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (u *UserUpsertOne) SetTotpEnabled(v bool) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabled(v)
+ })
+}
+
+// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotpEnabled() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabled()
+ })
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (u *UserUpsertOne) SetTotpEnabledAt(v time.Time) *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabledAt(v)
+ })
+}
+
+// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
+func (u *UserUpsertOne) UpdateTotpEnabledAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabledAt()
+ })
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpEnabledAt()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -1393,6 +1558,62 @@ func (u *UserUpsertBulk) UpdateNotes() *UserUpsertBulk {
})
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (u *UserUpsertBulk) SetTotpSecretEncrypted(v string) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpSecretEncrypted(v)
+ })
+}
+
+// UpdateTotpSecretEncrypted sets the "totp_secret_encrypted" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotpSecretEncrypted() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpSecretEncrypted()
+ })
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (u *UserUpsertBulk) ClearTotpSecretEncrypted() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpSecretEncrypted()
+ })
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (u *UserUpsertBulk) SetTotpEnabled(v bool) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabled(v)
+ })
+}
+
+// UpdateTotpEnabled sets the "totp_enabled" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotpEnabled() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabled()
+ })
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (u *UserUpsertBulk) SetTotpEnabledAt(v time.Time) *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.SetTotpEnabledAt(v)
+ })
+}
+
+// UpdateTotpEnabledAt sets the "totp_enabled_at" field to the value that was provided on create.
+func (u *UserUpsertBulk) UpdateTotpEnabledAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.UpdateTotpEnabledAt()
+ })
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
+ return u.Update(func(s *UserUpsert) {
+ s.ClearTotpEnabledAt()
+ })
+}
+
// Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go
index cf189fea..b98a41c6 100644
--- a/backend/ent/user_update.go
+++ b/backend/ent/user_update.go
@@ -187,6 +187,60 @@ func (_u *UserUpdate) SetNillableNotes(v *string) *UserUpdate {
return _u
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (_u *UserUpdate) SetTotpSecretEncrypted(v string) *UserUpdate {
+ _u.mutation.SetTotpSecretEncrypted(v)
+ return _u
+}
+
+// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotpSecretEncrypted(v *string) *UserUpdate {
+ if v != nil {
+ _u.SetTotpSecretEncrypted(*v)
+ }
+ return _u
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (_u *UserUpdate) ClearTotpSecretEncrypted() *UserUpdate {
+ _u.mutation.ClearTotpSecretEncrypted()
+ return _u
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (_u *UserUpdate) SetTotpEnabled(v bool) *UserUpdate {
+ _u.mutation.SetTotpEnabled(v)
+ return _u
+}
+
+// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotpEnabled(v *bool) *UserUpdate {
+ if v != nil {
+ _u.SetTotpEnabled(*v)
+ }
+ return _u
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (_u *UserUpdate) SetTotpEnabledAt(v time.Time) *UserUpdate {
+ _u.mutation.SetTotpEnabledAt(v)
+ return _u
+}
+
+// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
+func (_u *UserUpdate) SetNillableTotpEnabledAt(v *time.Time) *UserUpdate {
+ if v != nil {
+ _u.SetTotpEnabledAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
+ _u.mutation.ClearTotpEnabledAt()
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -603,6 +657,21 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
+ if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
+ _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
+ }
+ if _u.mutation.TotpSecretEncryptedCleared() {
+ _spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
+ }
+ if value, ok := _u.mutation.TotpEnabled(); ok {
+ _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.TotpEnabledAt(); ok {
+ _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpEnabledAtCleared() {
+ _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1147,6 +1216,60 @@ func (_u *UserUpdateOne) SetNillableNotes(v *string) *UserUpdateOne {
return _u
}
+// SetTotpSecretEncrypted sets the "totp_secret_encrypted" field.
+func (_u *UserUpdateOne) SetTotpSecretEncrypted(v string) *UserUpdateOne {
+ _u.mutation.SetTotpSecretEncrypted(v)
+ return _u
+}
+
+// SetNillableTotpSecretEncrypted sets the "totp_secret_encrypted" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotpSecretEncrypted(v *string) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotpSecretEncrypted(*v)
+ }
+ return _u
+}
+
+// ClearTotpSecretEncrypted clears the value of the "totp_secret_encrypted" field.
+func (_u *UserUpdateOne) ClearTotpSecretEncrypted() *UserUpdateOne {
+ _u.mutation.ClearTotpSecretEncrypted()
+ return _u
+}
+
+// SetTotpEnabled sets the "totp_enabled" field.
+func (_u *UserUpdateOne) SetTotpEnabled(v bool) *UserUpdateOne {
+ _u.mutation.SetTotpEnabled(v)
+ return _u
+}
+
+// SetNillableTotpEnabled sets the "totp_enabled" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotpEnabled(v *bool) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotpEnabled(*v)
+ }
+ return _u
+}
+
+// SetTotpEnabledAt sets the "totp_enabled_at" field.
+func (_u *UserUpdateOne) SetTotpEnabledAt(v time.Time) *UserUpdateOne {
+ _u.mutation.SetTotpEnabledAt(v)
+ return _u
+}
+
+// SetNillableTotpEnabledAt sets the "totp_enabled_at" field if the given value is not nil.
+func (_u *UserUpdateOne) SetNillableTotpEnabledAt(v *time.Time) *UserUpdateOne {
+ if v != nil {
+ _u.SetTotpEnabledAt(*v)
+ }
+ return _u
+}
+
+// ClearTotpEnabledAt clears the value of the "totp_enabled_at" field.
+func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
+ _u.mutation.ClearTotpEnabledAt()
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1593,6 +1716,21 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if value, ok := _u.mutation.Notes(); ok {
_spec.SetField(user.FieldNotes, field.TypeString, value)
}
+ if value, ok := _u.mutation.TotpSecretEncrypted(); ok {
+ _spec.SetField(user.FieldTotpSecretEncrypted, field.TypeString, value)
+ }
+ if _u.mutation.TotpSecretEncryptedCleared() {
+ _spec.ClearField(user.FieldTotpSecretEncrypted, field.TypeString)
+ }
+ if value, ok := _u.mutation.TotpEnabled(); ok {
+ _spec.SetField(user.FieldTotpEnabled, field.TypeBool, value)
+ }
+ if value, ok := _u.mutation.TotpEnabledAt(); ok {
+ _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
+ }
+ if _u.mutation.TotpEnabledAtCleared() {
+ _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/go.mod b/backend/go.mod
index fd429b07..ad7d76b6 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -37,6 +37,7 @@ require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
+ github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -106,6 +107,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
+ github.com/pquerna/otp v1.5.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index aa10718c..0addb5bb 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -20,6 +20,8 @@ github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -217,6 +219,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
+github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
+github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 00a78480..477cb59d 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -47,6 +47,7 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
Ops OpsConfig `mapstructure:"ops"`
JWT JWTConfig `mapstructure:"jwt"`
+ Totp TotpConfig `mapstructure:"totp"`
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
Default DefaultConfig `mapstructure:"default"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
@@ -466,6 +467,16 @@ type JWTConfig struct {
ExpireHour int `mapstructure:"expire_hour"`
}
+// TotpConfig TOTP 双因素认证配置
+type TotpConfig struct {
+ // EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码)
+ // 如果为空,将自动生成一个随机密钥(仅适用于开发环境)
+ EncryptionKey string `mapstructure:"encryption_key"`
+ // EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成)
+ // 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
+ EncryptionKeyConfigured bool `mapstructure:"-"`
+}
+
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
@@ -626,6 +637,20 @@ func Load() (*Config, error) {
log.Println("Warning: JWT secret auto-generated. Consider setting a fixed secret for production.")
}
+ // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256)
+ cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey)
+ if cfg.Totp.EncryptionKey == "" {
+ key, err := generateJWTSecret(32) // Reuse the same random generation function
+ if err != nil {
+ return nil, fmt.Errorf("generate totp encryption key error: %w", err)
+ }
+ cfg.Totp.EncryptionKey = key
+ cfg.Totp.EncryptionKeyConfigured = false
+ log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.")
+ } else {
+ cfg.Totp.EncryptionKeyConfigured = true
+ }
+
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validate config error: %w", err)
}
@@ -756,6 +781,9 @@ func setDefaults() {
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
+ // TOTP
+ viper.SetDefault("totp.encryption_key", "")
+
// Default
// Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP).
// Do not ship fixed defaults here to avoid insecure "known credentials" in production.
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 9cc2540d..bbf5d026 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
- // 如果 project_id 获取失败,先更新凭证,再标记账户为 error
+ // 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
+ // 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
+ if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
+ if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
+ newCredentials["project_id"] = oldProjectID
+ }
+ }
+
+ // 如果 project_id 获取失败,更新凭证但不标记为 error
+ // LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
if tokenInfo.ProjectIDMissing {
- // 先更新凭证
+ // 先更新凭证(token 本身刷新成功了)
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
@@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
return
}
- // 标记账户为 error
- if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
- response.InternalError(c, "Failed to set account error: "+setErr.Error())
- return
- }
+ // 不标记为 error,只返回警告信息
response.Success(c, gin.H{
- "message": "Token refreshed but project_id is missing, account marked as error",
- "warning": "missing_project_id",
+ "message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
+ "warning": "missing_project_id_temporary",
})
return
}
@@ -668,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
return
}
+ // 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token(触发刷新或从 DB 读取)
+ // 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
+ if h.tokenCacheInvalidator != nil && account.IsOAuth() {
+ if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
+ // 缓存失效失败只记录日志,不影响主流程
+ _ = c.Error(invalidateErr)
+ }
+ }
+
response.Success(c, dto.AccountFromService(account))
}
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 0e3e0a2f..4a798fa1 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -48,6 +48,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ TotpEnabled: settings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: settings.SMTPHost,
SMTPPort: settings.SMTPPort,
SMTPUsername: settings.SMTPUsername,
@@ -89,9 +92,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
// 邮件服务设置
SMTPHost string `json:"smtp_host"`
@@ -198,6 +203,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ // TOTP 双因素认证参数验证
+ // 只有手动配置了加密密钥才允许启用 TOTP 功能
+ if req.TotpEnabled && !previousSettings.TotpEnabled {
+ // 尝试启用 TOTP,检查加密密钥是否已手动配置
+ if !h.settingService.IsTotpEncryptionKeyConfigured() {
+ response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
+ return
+ }
+ }
+
// LinuxDo Connect 参数验证
if req.LinuxDoConnectEnabled {
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
@@ -243,6 +258,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
PromoCodeEnabled: req.PromoCodeEnabled,
+ PasswordResetEnabled: req.PasswordResetEnabled,
+ TotpEnabled: req.TotpEnabled,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
@@ -318,6 +335,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
RegistrationEnabled: updatedSettings.RegistrationEnabled,
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
+ PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
+ TotpEnabled: updatedSettings.TotpEnabled,
+ TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
SMTPHost: updatedSettings.SMTPHost,
SMTPPort: updatedSettings.SMTPPort,
SMTPUsername: updatedSettings.SMTPUsername,
@@ -384,6 +404,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
changed = append(changed, "email_verify_enabled")
}
+ if before.PasswordResetEnabled != after.PasswordResetEnabled {
+ changed = append(changed, "password_reset_enabled")
+ }
+ if before.TotpEnabled != after.TotpEnabled {
+ changed = append(changed, "totp_enabled")
+ }
if before.SMTPHost != after.SMTPHost {
changed = append(changed, "smtp_host")
}
diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go
index a0d1456f..51995ab1 100644
--- a/backend/internal/handler/admin/subscription_handler.go
+++ b/backend/internal/handler/admin/subscription_handler.go
@@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
}
status := c.Query("status")
- subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
+ // Parse sorting parameters
+ sortBy := c.DefaultQuery("sort_by", "created_at")
+ sortOrder := c.DefaultQuery("sort_order", "desc")
+
+ subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index 89f34aae..3522407d 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -1,6 +1,8 @@
package handler
import (
+ "log/slog"
+
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
@@ -18,16 +20,18 @@ type AuthHandler struct {
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
+ totpService *service.TotpService
}
// NewAuthHandler creates a new AuthHandler
-func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
+func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
+ totpService: totpService,
}
}
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
return
}
+ // Check if TOTP 2FA is enabled for this user
+ if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
+ // Create a temporary login session for 2FA
+ tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
+ if err != nil {
+ response.InternalError(c, "Failed to create 2FA session")
+ return
+ }
+
+ response.Success(c, TotpLoginResponse{
+ Requires2FA: true,
+ TempToken: tempToken,
+ UserEmailMasked: service.MaskEmail(user.Email),
+ })
+ return
+ }
+
+ response.Success(c, AuthResponse{
+ AccessToken: token,
+ TokenType: "Bearer",
+ User: dto.UserFromService(user),
+ })
+}
+
+// TotpLoginResponse represents the response when 2FA is required
+type TotpLoginResponse struct {
+ Requires2FA bool `json:"requires_2fa"`
+ TempToken string `json:"temp_token,omitempty"`
+ UserEmailMasked string `json:"user_email_masked,omitempty"`
+}
+
+// Login2FARequest represents the 2FA login request
+type Login2FARequest struct {
+ TempToken string `json:"temp_token" binding:"required"`
+ TotpCode string `json:"totp_code" binding:"required,len=6"`
+}
+
+// Login2FA completes the login with 2FA verification
+// POST /api/v1/auth/login/2fa
+func (h *AuthHandler) Login2FA(c *gin.Context) {
+ var req Login2FARequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ slog.Debug("login_2fa_request",
+ "temp_token_len", len(req.TempToken),
+ "totp_code_len", len(req.TotpCode))
+
+ // Get the login session
+ session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
+ if err != nil || session == nil {
+ tokenPrefix := ""
+ if len(req.TempToken) >= 8 {
+ tokenPrefix = req.TempToken[:8]
+ }
+ slog.Debug("login_2fa_session_invalid",
+ "temp_token_prefix", tokenPrefix,
+ "error", err)
+ response.BadRequest(c, "Invalid or expired 2FA session")
+ return
+ }
+
+ slog.Debug("login_2fa_session_found",
+ "user_id", session.UserID,
+ "email", session.Email)
+
+ // Verify the TOTP code
+ if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
+ slog.Debug("login_2fa_verify_failed",
+ "user_id", session.UserID,
+ "error", err)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Delete the login session
+ _ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
+
+ // Get the user
+ user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Generate the JWT token
+ token, err := h.authService.GenerateToken(user)
+ if err != nil {
+ response.InternalError(c, "Failed to generate token")
+ return
+ }
+
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
@@ -247,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
BonusAmount: promoCode.BonusAmount,
})
}
+
+// ForgotPasswordRequest 忘记密码请求
+type ForgotPasswordRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ TurnstileToken string `json:"turnstile_token"`
+}
+
+// ForgotPasswordResponse 忘记密码响应
+type ForgotPasswordResponse struct {
+ Message string `json:"message"`
+}
+
+// ForgotPassword 请求密码重置
+// POST /api/v1/auth/forgot-password
+func (h *AuthHandler) ForgotPassword(c *gin.Context) {
+ var req ForgotPasswordRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ // Turnstile 验证
+ if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // Build frontend base URL from request
+ scheme := "https"
+ if c.Request.TLS == nil {
+ // Check X-Forwarded-Proto header (common in reverse proxy setups)
+ if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
+ scheme = proto
+ } else {
+ scheme = "http"
+ }
+ }
+ frontendBaseURL := scheme + "://" + c.Request.Host
+
+ // Request password reset (async)
+ // Note: This returns success even if email doesn't exist (to prevent enumeration)
+ if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, ForgotPasswordResponse{
+ Message: "If your email is registered, you will receive a password reset link shortly.",
+ })
+}
+
+// ResetPasswordRequest 重置密码请求
+type ResetPasswordRequest struct {
+ Email string `json:"email" binding:"required,email"`
+ Token string `json:"token" binding:"required"`
+ NewPassword string `json:"new_password" binding:"required,min=6"`
+}
+
+// ResetPasswordResponse 重置密码响应
+type ResetPasswordResponse struct {
+ Message string `json:"message"`
+}
+
+// ResetPassword 重置密码
+// POST /api/v1/auth/reset-password
+func (h *AuthHandler) ResetPassword(c *gin.Context) {
+ var req ResetPasswordRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ // Reset password
+ if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, ResetPasswordResponse{
+ Message: "Your password has been reset successfully. You can now log in with your new password.",
+ })
+}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 01f39478..fc7b1349 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -2,9 +2,12 @@ package dto
// SystemSettings represents the admin settings API response payload.
type SystemSettings struct {
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
+ TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
SMTPHost string `json:"smtp_host"`
SMTPPort int `json:"smtp_port"`
@@ -54,21 +57,23 @@ type SystemSettings struct {
}
type PublicSettings struct {
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
- TurnstileEnabled bool `json:"turnstile_enabled"`
- TurnstileSiteKey string `json:"turnstile_site_key"`
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo"`
- SiteSubtitle string `json:"site_subtitle"`
- APIBaseURL string `json:"api_base_url"`
- ContactInfo string `json:"contact_info"`
- DocURL string `json:"doc_url"`
- HomeContent string `json:"home_content"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
- LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
- Version string `json:"version"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
+ TurnstileEnabled bool `json:"turnstile_enabled"`
+ TurnstileSiteKey string `json:"turnstile_site_key"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo"`
+ SiteSubtitle string `json:"site_subtitle"`
+ APIBaseURL string `json:"api_base_url"`
+ ContactInfo string `json:"contact_info"`
+ DocURL string `json:"doc_url"`
+ HomeContent string `json:"home_content"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index f38fea39..8fd6d2b9 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
- // 检查预热请求拦截(在账号选择后、转发前检查)
- if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
- if selection.Acquired && selection.ReleaseFunc != nil {
- selection.ReleaseFunc()
+ // 检查请求拦截(预热请求、SUGGESTION MODE等)
+ if account.IsInterceptWarmupEnabled() {
+ interceptType := detectInterceptType(body)
+ if interceptType != InterceptTypeNone {
+ if selection.Acquired && selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if reqStream {
+ sendMockInterceptStream(c, reqModel, interceptType)
+ } else {
+ sendMockInterceptResponse(c, reqModel, interceptType)
+ }
+ return
}
- if reqStream {
- sendMockWarmupStream(c, reqModel)
- } else {
- sendMockWarmupResponse(c, reqModel)
- }
- return
}
// 3. 获取账号并发槽位
@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
- // 检查预热请求拦截(在账号选择后、转发前检查)
- if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
- if selection.Acquired && selection.ReleaseFunc != nil {
- selection.ReleaseFunc()
+ // 检查请求拦截(预热请求、SUGGESTION MODE等)
+ if account.IsInterceptWarmupEnabled() {
+ interceptType := detectInterceptType(body)
+ if interceptType != InterceptTypeNone {
+ if selection.Acquired && selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ if reqStream {
+ sendMockInterceptStream(c, reqModel, interceptType)
+ } else {
+ sendMockInterceptResponse(c, reqModel, interceptType)
+ }
+ return
}
- if reqStream {
- sendMockWarmupStream(c, reqModel)
- } else {
- sendMockWarmupResponse(c, reqModel)
- }
- return
}
// 3. 获取账号并发槽位
@@ -768,17 +774,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
}
}
-// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
-func isWarmupRequest(body []byte) bool {
- // 快速检查:如果body不包含关键字,直接返回false
+// InterceptType 表示请求拦截类型
+type InterceptType int
+
+const (
+ InterceptTypeNone InterceptType = iota
+ InterceptTypeWarmup // 预热请求(返回 "New Conversation")
+ InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
+)
+
+// detectInterceptType 检测请求是否需要拦截,返回拦截类型
+func detectInterceptType(body []byte) InterceptType {
+ // 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
- if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
- return false
+ hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
+ hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
+
+ if !hasSuggestionMode && !hasWarmupKeyword {
+ return InterceptTypeNone
}
- // 解析完整请求
+ // 解析请求(只解析一次)
var req struct {
Messages []struct {
+ Role string `json:"role"`
Content []struct {
Type string `json:"type"`
Text string `json:"text"`
@@ -789,43 +808,71 @@ func isWarmupRequest(body []byte) bool {
} `json:"system"`
}
if err := json.Unmarshal(body, &req); err != nil {
- return false
+ return InterceptTypeNone
}
- // 检查 messages 中的标题提示模式
- for _, msg := range req.Messages {
- for _, content := range msg.Content {
- if content.Type == "text" {
- if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
- content.Text == "Warmup" {
- return true
+ // 检查 SUGGESTION MODE(最后一条 user 消息)
+ if hasSuggestionMode && len(req.Messages) > 0 {
+ lastMsg := req.Messages[len(req.Messages)-1]
+ if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
+ lastMsg.Content[0].Type == "text" &&
+ strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
+ return InterceptTypeSuggestionMode
+ }
+ }
+
+ // 检查 Warmup 请求
+ if hasWarmupKeyword {
+ // 检查 messages 中的标题提示模式
+ for _, msg := range req.Messages {
+ for _, content := range msg.Content {
+ if content.Type == "text" {
+ if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
+ content.Text == "Warmup" {
+ return InterceptTypeWarmup
+ }
}
}
}
+ // 检查 system 中的标题提取模式
+ for _, sys := range req.System {
+ if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
+ return InterceptTypeWarmup
+ }
+ }
}
- // 检查 system 中的标题提取模式
- for _, system := range req.System {
- if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
- return true
- }
- }
-
- return false
+ return InterceptTypeNone
}
-// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
-func sendMockWarmupStream(c *gin.Context, model string) {
+// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
+func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
+ // 根据拦截类型决定响应内容
+ var msgID string
+ var outputTokens int
+ var textDeltas []string
+
+ switch interceptType {
+ case InterceptTypeSuggestionMode:
+ msgID = "msg_mock_suggestion"
+ outputTokens = 1
+ textDeltas = []string{""} // 空内容
+ default: // InterceptTypeWarmup
+ msgID = "msg_mock_warmup"
+ outputTokens = 2
+ textDeltas = []string{"New", " Conversation"}
+ }
+
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
- "id": "msg_mock_warmup",
+ "id": msgID,
"type": "message",
"role": "assistant",
"model": model,
@@ -840,16 +887,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
messageStartJSON, _ := json.Marshal(messageStart)
+ // Build events
events := []string{
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
- `event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
- `event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
- `event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
- `event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
- `event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
}
+ // Add text deltas
+ for _, text := range textDeltas {
+ delta := map[string]any{
+ "type": "content_block_delta",
+ "index": 0,
+ "delta": map[string]string{
+ "type": "text_delta",
+ "text": text,
+ },
+ }
+ deltaJSON, _ := json.Marshal(delta)
+ events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
+ }
+
+ // Add final events
+ messageDelta := map[string]any{
+ "type": "message_delta",
+ "delta": map[string]any{
+ "stop_reason": "end_turn",
+ "stop_sequence": nil,
+ },
+ "usage": map[string]int{
+ "input_tokens": 10,
+ "output_tokens": outputTokens,
+ },
+ }
+ messageDeltaJSON, _ := json.Marshal(messageDelta)
+
+ events = append(events,
+ `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
+ `event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
+ `event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
+ )
+
for _, event := range events {
_, _ = c.Writer.WriteString(event + "\n\n")
c.Writer.Flush()
@@ -857,18 +934,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
}
}
-// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
-func sendMockWarmupResponse(c *gin.Context, model string) {
+// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
+func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
+ var msgID, text string
+ var outputTokens int
+
+ switch interceptType {
+ case InterceptTypeSuggestionMode:
+ msgID = "msg_mock_suggestion"
+ text = ""
+ outputTokens = 1
+ default: // InterceptTypeWarmup
+ msgID = "msg_mock_warmup"
+ text = "New Conversation"
+ outputTokens = 2
+ }
+
c.JSON(http.StatusOK, gin.H{
- "id": "msg_mock_warmup",
+ "id": msgID,
"type": "message",
"role": "assistant",
"model": model,
- "content": []gin.H{{"type": "text", "text": "New Conversation"}},
+ "content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
"usage": gin.H{
"input_tokens": 10,
- "output_tokens": 2,
+ "output_tokens": outputTokens,
},
})
}
diff --git a/backend/internal/handler/gemini_cli_session_test.go b/backend/internal/handler/gemini_cli_session_test.go
new file mode 100644
index 00000000..0b37f5f2
--- /dev/null
+++ b/backend/internal/handler/gemini_cli_session_test.go
@@ -0,0 +1,122 @@
+//go:build unit
+
+package handler
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestExtractGeminiCLISessionHash(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ privilegedUserID string
+ wantEmpty bool
+ wantHash string
+ }{
+ {
+ name: "with privileged-user-id and tmp dir",
+ body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
+ privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
+ wantEmpty: false,
+ wantHash: func() string {
+ combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
+ hash := sha256.Sum256([]byte(combined))
+ return hex.EncodeToString(hash[:])
+ }(),
+ },
+ {
+ name: "without privileged-user-id but with tmp dir",
+ body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
+ privilegedUserID: "",
+ wantEmpty: false,
+ wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ },
+ {
+ name: "without tmp dir",
+ body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
+ privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
+ wantEmpty: true,
+ },
+ {
+ name: "empty body",
+ body: "",
+ privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
+ wantEmpty: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // 创建测试上下文
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest("POST", "/test", nil)
+ if tt.privilegedUserID != "" {
+ c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
+ }
+
+ // 调用函数
+ result := extractGeminiCLISessionHash(c, []byte(tt.body))
+
+ // 验证结果
+ if tt.wantEmpty {
+ require.Empty(t, result, "expected empty session hash")
+ } else {
+ require.NotEmpty(t, result, "expected non-empty session hash")
+ require.Equal(t, tt.wantHash, result, "session hash mismatch")
+ }
+ })
+ }
+}
+
+func TestGeminiCLITmpDirRegex(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantMatch bool
+ wantHash string
+ }{
+ {
+ name: "valid tmp dir path",
+ input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ wantMatch: true,
+ wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ },
+ {
+ name: "valid tmp dir path in text",
+ input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
+ wantMatch: true,
+ wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
+ },
+ {
+ name: "invalid hash length",
+ input: "/Users/ianshaw/.gemini/tmp/abc123",
+ wantMatch: false,
+ },
+ {
+ name: "no tmp dir",
+ input: "Hello world",
+ wantMatch: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
+ if tt.wantMatch {
+ require.NotNil(t, match, "expected regex to match")
+ require.Len(t, match, 2, "expected 2 capture groups")
+ require.Equal(t, tt.wantHash, match[1], "hash mismatch")
+ } else {
+ require.Nil(t, match, "expected regex not to match")
+ }
+ })
+ }
+}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index c7646b38..32f83013 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -1,11 +1,15 @@
package handler
import (
+ "bytes"
"context"
+ "crypto/sha256"
+ "encoding/hex"
"errors"
"io"
"log"
"net/http"
+ "regexp"
"strings"
"time"
@@ -19,6 +23,17 @@ import (
"github.com/gin-gonic/gin"
)
+// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
+// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
+var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
+
+func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
+ if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
+ return true
+ }
+ return geminiCLITmpDirRegex.Match(body)
+}
+
// GeminiV1BetaListModels proxies:
// GET /v1beta/models
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 3) select account (sticky session based on request body)
- parsedReq, _ := service.ParseGatewayRequest(body)
- sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
+ // 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
+ sessionHash := extractGeminiCLISessionHash(c, body)
+ if sessionHash == "" {
+ // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
+ parsedReq, _ := service.ParseGatewayRequest(body)
+ sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
+ }
sessionKey := sessionHash
if sessionHash != "" {
sessionKey = "gemini:" + sessionHash
}
+
+ // 查询粘性会话绑定的账号 ID(用于检测账号切换)
+ var sessionBoundAccountID int64
+ if sessionKey != "" {
+ sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
+ }
+ isCLI := isGeminiCLIRequest(c, body)
+ cleanedForUnknownBinding := false
+
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
account := selection.Account
setOpsSelectedAccount(c, account.ID)
+ // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
+ // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
+ if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
+ log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
+ body = service.CleanGeminiNativeThoughtSignatures(body)
+ sessionBoundAccountID = account.ID
+ } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
+ // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
+ // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
+ log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
+ body = service.CleanGeminiNativeThoughtSignatures(body)
+ cleanedForUnknownBinding = true
+ sessionBoundAccountID = account.ID
+ } else if sessionBoundAccountID == 0 {
+ // 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
+ sessionBoundAccountID = account.ID
+ }
+
// 4) account concurrency slot
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
}
return false
}
+
+// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
+// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
+//
+// 会话标识生成策略:
+// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
+// 2. 从 header 中提取 privileged-user-id(UUID)
+// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
+//
+// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
+//
+// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
+// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
+func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
+ // 1. 从请求体中提取 tmp 目录哈希
+ match := geminiCLITmpDirRegex.FindSubmatch(body)
+ if len(match) < 2 {
+ return "" // 没有找到 tmp 目录,不使用粘性会话
+ }
+ tmpDirHash := string(match[1])
+
+ // 2. 提取 privileged-user-id
+ privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
+
+ // 3. 组合生成最终的 session hash
+ if privilegedUserID != "" {
+ // 组合两个标识符:privileged-user-id + tmp 目录哈希
+ combined := privilegedUserID + ":" + tmpDirHash
+ hash := sha256.Sum256([]byte(combined))
+ return hex.EncodeToString(hash[:])
+ }
+
+ // 如果没有 privileged-user-id,直接使用 tmp 目录哈希
+ return tmpDirHash
+}
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index 5b1b317d..907c314d 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -37,6 +37,7 @@ type Handlers struct {
Gateway *GatewayHandler
OpenAIGateway *OpenAIGatewayHandler
Setting *SettingHandler
+ Totp *TotpHandler
}
// BuildInfo contains build-time information
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 8723c746..9c0bde33 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -32,20 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
}
response.Success(c, dto.PublicSettings{
- RegistrationEnabled: settings.RegistrationEnabled,
- EmailVerifyEnabled: settings.EmailVerifyEnabled,
- PromoCodeEnabled: settings.PromoCodeEnabled,
- TurnstileEnabled: settings.TurnstileEnabled,
- TurnstileSiteKey: settings.TurnstileSiteKey,
- SiteName: settings.SiteName,
- SiteLogo: settings.SiteLogo,
- SiteSubtitle: settings.SiteSubtitle,
- APIBaseURL: settings.APIBaseURL,
- ContactInfo: settings.ContactInfo,
- DocURL: settings.DocURL,
- HomeContent: settings.HomeContent,
- HideCcsImportButton: settings.HideCcsImportButton,
- LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
- Version: h.version,
+ RegistrationEnabled: settings.RegistrationEnabled,
+ EmailVerifyEnabled: settings.EmailVerifyEnabled,
+ PromoCodeEnabled: settings.PromoCodeEnabled,
+ PasswordResetEnabled: settings.PasswordResetEnabled,
+ TurnstileEnabled: settings.TurnstileEnabled,
+ TurnstileSiteKey: settings.TurnstileSiteKey,
+ SiteName: settings.SiteName,
+ SiteLogo: settings.SiteLogo,
+ SiteSubtitle: settings.SiteSubtitle,
+ APIBaseURL: settings.APIBaseURL,
+ ContactInfo: settings.ContactInfo,
+ DocURL: settings.DocURL,
+ HomeContent: settings.HomeContent,
+ HideCcsImportButton: settings.HideCcsImportButton,
+ LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ Version: h.version,
})
}
diff --git a/backend/internal/handler/totp_handler.go b/backend/internal/handler/totp_handler.go
new file mode 100644
index 00000000..5c5eb567
--- /dev/null
+++ b/backend/internal/handler/totp_handler.go
@@ -0,0 +1,181 @@
+package handler
+
+import (
+ "github.com/gin-gonic/gin"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// TotpHandler handles TOTP-related requests
+type TotpHandler struct {
+ totpService *service.TotpService
+}
+
+// NewTotpHandler creates a new TotpHandler
+func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
+ return &TotpHandler{
+ totpService: totpService,
+ }
+}
+
+// TotpStatusResponse represents the TOTP status response
+type TotpStatusResponse struct {
+ Enabled bool `json:"enabled"`
+ EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
+ FeatureEnabled bool `json:"feature_enabled"`
+}
+
+// GetStatus returns the TOTP status for the current user
+// GET /api/v1/user/totp/status
+func (h *TotpHandler) GetStatus(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ resp := TotpStatusResponse{
+ Enabled: status.Enabled,
+ FeatureEnabled: status.FeatureEnabled,
+ }
+
+ if status.EnabledAt != nil {
+ ts := status.EnabledAt.Unix()
+ resp.EnabledAt = &ts
+ }
+
+ response.Success(c, resp)
+}
+
+// TotpSetupRequest represents the request to initiate TOTP setup
+type TotpSetupRequest struct {
+ EmailCode string `json:"email_code"`
+ Password string `json:"password"`
+}
+
+// TotpSetupResponse represents the TOTP setup response
+type TotpSetupResponse struct {
+ Secret string `json:"secret"`
+ QRCodeURL string `json:"qr_code_url"`
+ SetupToken string `json:"setup_token"`
+ Countdown int `json:"countdown"`
+}
+
+// InitiateSetup starts the TOTP setup process
+// POST /api/v1/user/totp/setup
+func (h *TotpHandler) InitiateSetup(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req TotpSetupRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ // Allow empty body (optional params)
+ req = TotpSetupRequest{}
+ }
+
+ result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, TotpSetupResponse{
+ Secret: result.Secret,
+ QRCodeURL: result.QRCodeURL,
+ SetupToken: result.SetupToken,
+ Countdown: result.Countdown,
+ })
+}
+
+// TotpEnableRequest represents the request to enable TOTP
+type TotpEnableRequest struct {
+ TotpCode string `json:"totp_code" binding:"required,len=6"`
+ SetupToken string `json:"setup_token" binding:"required"`
+}
+
+// Enable completes the TOTP setup
+// POST /api/v1/user/totp/enable
+func (h *TotpHandler) Enable(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req TotpEnableRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"success": true})
+}
+
+// TotpDisableRequest represents the request to disable TOTP
+type TotpDisableRequest struct {
+ EmailCode string `json:"email_code"`
+ Password string `json:"password"`
+}
+
+// Disable disables TOTP for the current user
+// POST /api/v1/user/totp/disable
+func (h *TotpHandler) Disable(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ var req TotpDisableRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"success": true})
+}
+
+// GetVerificationMethod returns the verification method for TOTP operations
+// GET /api/v1/user/totp/verification-method
+func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
+ method := h.totpService.GetVerificationMethod(c.Request.Context())
+ response.Success(c, method)
+}
+
+// SendVerifyCode sends an email verification code for TOTP operations
+// POST /api/v1/user/totp/send-code
+func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
+ subject, ok := middleware2.GetAuthSubjectFromContext(c)
+ if !ok {
+ response.Unauthorized(c, "User not authenticated")
+ return
+ }
+
+ if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"success": true})
+}
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 2af7905e..92e8edeb 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -70,6 +70,7 @@ func ProvideHandlers(
gatewayHandler *GatewayHandler,
openaiGatewayHandler *OpenAIGatewayHandler,
settingHandler *SettingHandler,
+ totpHandler *TotpHandler,
) *Handlers {
return &Handlers{
Auth: authHandler,
@@ -82,6 +83,7 @@ func ProvideHandlers(
Gateway: gatewayHandler,
OpenAIGateway: openaiGatewayHandler,
Setting: settingHandler,
+ Totp: totpHandler,
}
}
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionHandler,
NewGatewayHandler,
NewOpenAIGatewayHandler,
+ NewTotpHandler,
ProvideSettingHandler,
// Admin handlers
diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go
index 637a4ea8..63f6ee7c 100644
--- a/backend/internal/pkg/antigravity/request_transformer.go
+++ b/backend/internal/pkg/antigravity/request_transformer.go
@@ -7,13 +7,11 @@ import (
"fmt"
"log"
"math/rand"
- "os"
"strconv"
"strings"
"sync"
"time"
- "github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -369,8 +367,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
Text: block.Thinking,
Thought: true,
}
- // 保留原有 signature(Claude 模型需要有效的 signature)
- if block.Signature != "" {
+ // signature 处理:
+ // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
+ // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
+ if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
} else if !allowDummyThought {
// Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。
@@ -409,12 +409,12 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu
},
}
// tool_use 的 signature 处理:
- // - Gemini 模型:使用 dummy signature(跳过 thought_signature 校验)
- // - Claude 模型:透传上游返回的真实 signature(Vertex/Google 需要完整签名链路)
- if allowDummyThought {
- part.ThoughtSignature = dummyThoughtSignature
- } else if block.Signature != "" && block.Signature != dummyThoughtSignature {
+ // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失)
+ // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature
+ if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) {
part.ThoughtSignature = block.Signature
+ } else if allowDummyThought {
+ part.ThoughtSignature = dummyThoughtSignature
}
parts = append(parts, part)
@@ -594,11 +594,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
}
// 清理 JSON Schema
- params := cleanJSONSchema(inputSchema)
+ // 1. 深度清理 [undefined] 值
+ DeepCleanUndefined(inputSchema)
+ // 2. 转换为符合 Gemini v1internal 的 schema
+ params := CleanJSONSchema(inputSchema)
// 为 nil schema 提供默认值
if params == nil {
params = map[string]any{
- "type": "OBJECT",
+ "type": "object", // lowercase type
"properties": map[string]any{},
}
}
@@ -631,236 +634,3 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
FunctionDeclarations: funcDecls,
}}
}
-
-// cleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
-// 参考 proxycast 的实现,确保 schema 符合 JSON Schema draft 2020-12
-func cleanJSONSchema(schema map[string]any) map[string]any {
- if schema == nil {
- return nil
- }
- cleaned := cleanSchemaValue(schema, "$")
- result, ok := cleaned.(map[string]any)
- if !ok {
- return nil
- }
-
- // 确保有 type 字段(默认 OBJECT)
- if _, hasType := result["type"]; !hasType {
- result["type"] = "OBJECT"
- }
-
- // 确保有 properties 字段(默认空对象)
- if _, hasProps := result["properties"]; !hasProps {
- result["properties"] = make(map[string]any)
- }
-
- // 验证 required 中的字段都存在于 properties 中
- if required, ok := result["required"].([]any); ok {
- if props, ok := result["properties"].(map[string]any); ok {
- validRequired := make([]any, 0, len(required))
- for _, r := range required {
- if reqName, ok := r.(string); ok {
- if _, exists := props[reqName]; exists {
- validRequired = append(validRequired, r)
- }
- }
- }
- if len(validRequired) > 0 {
- result["required"] = validRequired
- } else {
- delete(result, "required")
- }
- }
- }
-
- return result
-}
-
-var schemaValidationKeys = map[string]bool{
- "minLength": true,
- "maxLength": true,
- "pattern": true,
- "minimum": true,
- "maximum": true,
- "exclusiveMinimum": true,
- "exclusiveMaximum": true,
- "multipleOf": true,
- "uniqueItems": true,
- "minItems": true,
- "maxItems": true,
- "minProperties": true,
- "maxProperties": true,
- "patternProperties": true,
- "propertyNames": true,
- "dependencies": true,
- "dependentSchemas": true,
- "dependentRequired": true,
-}
-
-var warnedSchemaKeys sync.Map
-
-func schemaCleaningWarningsEnabled() bool {
- // 可通过环境变量强制开关,方便排查:SUB2API_SCHEMA_CLEAN_WARN=true/false
- if v := strings.TrimSpace(os.Getenv("SUB2API_SCHEMA_CLEAN_WARN")); v != "" {
- switch strings.ToLower(v) {
- case "1", "true", "yes", "on":
- return true
- case "0", "false", "no", "off":
- return false
- }
- }
- // 默认:非 release 模式下输出(debug/test)
- return gin.Mode() != gin.ReleaseMode
-}
-
-func warnSchemaKeyRemovedOnce(key, path string) {
- if !schemaCleaningWarningsEnabled() {
- return
- }
- if !schemaValidationKeys[key] {
- return
- }
- if _, loaded := warnedSchemaKeys.LoadOrStore(key, struct{}{}); loaded {
- return
- }
- log.Printf("[SchemaClean] removed unsupported JSON Schema validation field key=%q path=%q", key, path)
-}
-
-// excludedSchemaKeys 不支持的 schema 字段
-// 基于 Claude API (Vertex AI) 的实际支持情况
-// 支持: type, description, enum, properties, required, additionalProperties, items
-// 不支持: minItems, maxItems, minLength, maxLength, pattern, minimum, maximum 等验证字段
-var excludedSchemaKeys = map[string]bool{
- // 元 schema 字段
- "$schema": true,
- "$id": true,
- "$ref": true,
-
- // 字符串验证(Gemini 不支持)
- "minLength": true,
- "maxLength": true,
- "pattern": true,
-
- // 数字验证(Claude API 通过 Vertex AI 不支持这些字段)
- "minimum": true,
- "maximum": true,
- "exclusiveMinimum": true,
- "exclusiveMaximum": true,
- "multipleOf": true,
-
- // 数组验证(Claude API 通过 Vertex AI 不支持这些字段)
- "uniqueItems": true,
- "minItems": true,
- "maxItems": true,
-
- // 组合 schema(Gemini 不支持)
- "oneOf": true,
- "anyOf": true,
- "allOf": true,
- "not": true,
- "if": true,
- "then": true,
- "else": true,
- "$defs": true,
- "definitions": true,
-
- // 对象验证(仅保留 properties/required/additionalProperties)
- "minProperties": true,
- "maxProperties": true,
- "patternProperties": true,
- "propertyNames": true,
- "dependencies": true,
- "dependentSchemas": true,
- "dependentRequired": true,
-
- // 其他不支持的字段
- "default": true,
- "const": true,
- "examples": true,
- "deprecated": true,
- "readOnly": true,
- "writeOnly": true,
- "contentMediaType": true,
- "contentEncoding": true,
-
- // Claude 特有字段
- "strict": true,
-}
-
-// cleanSchemaValue 递归清理 schema 值
-func cleanSchemaValue(value any, path string) any {
- switch v := value.(type) {
- case map[string]any:
- result := make(map[string]any)
- for k, val := range v {
- // 跳过不支持的字段
- if excludedSchemaKeys[k] {
- warnSchemaKeyRemovedOnce(k, path)
- continue
- }
-
- // 特殊处理 type 字段
- if k == "type" {
- result[k] = cleanTypeValue(val)
- continue
- }
-
- // 特殊处理 format 字段:只保留 Gemini 支持的 format 值
- if k == "format" {
- if formatStr, ok := val.(string); ok {
- // Gemini 只支持 date-time, date, time
- if formatStr == "date-time" || formatStr == "date" || formatStr == "time" {
- result[k] = val
- }
- // 其他 format 值直接跳过
- }
- continue
- }
-
- // 特殊处理 additionalProperties:Claude API 只支持布尔值,不支持 schema 对象
- if k == "additionalProperties" {
- if boolVal, ok := val.(bool); ok {
- result[k] = boolVal
- } else {
- // 如果是 schema 对象,转换为 false(更安全的默认值)
- result[k] = false
- }
- continue
- }
-
- // 递归清理所有值
- result[k] = cleanSchemaValue(val, path+"."+k)
- }
- return result
-
- case []any:
- // 递归处理数组中的每个元素
- cleaned := make([]any, 0, len(v))
- for i, item := range v {
- cleaned = append(cleaned, cleanSchemaValue(item, fmt.Sprintf("%s[%d]", path, i)))
- }
- return cleaned
-
- default:
- return value
- }
-}
-
-// cleanTypeValue 处理 type 字段,转换为大写
-func cleanTypeValue(value any) any {
- switch v := value.(type) {
- case string:
- return strings.ToUpper(v)
- case []any:
- // 联合类型 ["string", "null"] -> 取第一个非 null 类型
- for _, t := range v {
- if ts, ok := t.(string); ok && ts != "null" {
- return strings.ToUpper(ts)
- }
- }
- // 如果只有 null,返回 STRING
- return "STRING"
- default:
- return value
- }
-}
diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go
index 60ee6f63..9d62a4a1 100644
--- a/backend/internal/pkg/antigravity/request_transformer_test.go
+++ b/backend/internal/pkg/antigravity/request_transformer_test.go
@@ -100,7 +100,7 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
{"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}, "signature": "sig_tool_abc"}
]`
- t.Run("Gemini uses dummy tool_use signature", func(t *testing.T) {
+ t.Run("Gemini preserves provided tool_use signature", func(t *testing.T) {
toolIDToName := make(map[string]string)
parts, _, err := buildParts(json.RawMessage(content), toolIDToName, true)
if err != nil {
@@ -109,6 +109,23 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) {
if len(parts) != 1 || parts[0].FunctionCall == nil {
t.Fatalf("expected 1 functionCall part, got %+v", parts)
}
+ if parts[0].ThoughtSignature != "sig_tool_abc" {
+ t.Fatalf("expected preserved tool signature %q, got %q", "sig_tool_abc", parts[0].ThoughtSignature)
+ }
+ })
+
+ t.Run("Gemini falls back to dummy tool_use signature when missing", func(t *testing.T) {
+ contentNoSig := `[
+ {"type": "tool_use", "id": "t1", "name": "Bash", "input": {"command": "ls"}}
+ ]`
+ toolIDToName := make(map[string]string)
+ parts, _, err := buildParts(json.RawMessage(contentNoSig), toolIDToName, true)
+ if err != nil {
+ t.Fatalf("buildParts() error = %v", err)
+ }
+ if len(parts) != 1 || parts[0].FunctionCall == nil {
+ t.Fatalf("expected 1 functionCall part, got %+v", parts)
+ }
if parts[0].ThoughtSignature != dummyThoughtSignature {
t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature)
}
diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go
index 04424c03..eb16f09d 100644
--- a/backend/internal/pkg/antigravity/response_transformer.go
+++ b/backend/internal/pkg/antigravity/response_transformer.go
@@ -3,6 +3,7 @@ package antigravity
import (
"encoding/json"
"fmt"
+ "log"
"strings"
)
@@ -19,6 +20,15 @@ func TransformGeminiToClaude(geminiResp []byte, originalModel string) ([]byte, *
v1Resp.Response = directResp
v1Resp.ResponseID = directResp.ResponseID
v1Resp.ModelVersion = directResp.ModelVersion
+ } else if len(v1Resp.Response.Candidates) == 0 {
+ // 第一次解析成功但 candidates 为空,说明是直接的 GeminiResponse 格式
+ var directResp GeminiResponse
+ if err2 := json.Unmarshal(geminiResp, &directResp); err2 != nil {
+ return nil, nil, fmt.Errorf("parse gemini response as direct: %w", err2)
+ }
+ v1Resp.Response = directResp
+ v1Resp.ResponseID = directResp.ResponseID
+ v1Resp.ModelVersion = directResp.ModelVersion
}
// 使用处理器转换
@@ -173,16 +183,20 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) {
p.trailingSignature = ""
}
- p.textBuilder += part.Text
-
- // 非空 text 带签名 - 立即刷新并输出空 thinking 块
+ // 非空 text 带签名 - 特殊处理:先输出 text,再输出空 thinking 块
if signature != "" {
- p.flushText()
+ p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
+ Type: "text",
+ Text: part.Text,
+ })
p.contentBlocks = append(p.contentBlocks, ClaudeContentItem{
Type: "thinking",
Thinking: "",
Signature: signature,
})
+ } else {
+ // 普通 text (无签名) - 累积到 builder
+ p.textBuilder += part.Text
}
}
}
@@ -242,6 +256,14 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon
var finishReason string
if len(geminiResp.Candidates) > 0 {
finishReason = geminiResp.Candidates[0].FinishReason
+ if finishReason == "MALFORMED_FUNCTION_CALL" {
+ log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in response for model %s", originalModel)
+ if geminiResp.Candidates[0].Content != nil {
+ if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
+ log.Printf("[Antigravity] Malformed content: %s", string(b))
+ }
+ }
+ }
}
stopReason := "end_turn"
diff --git a/backend/internal/pkg/antigravity/schema_cleaner.go b/backend/internal/pkg/antigravity/schema_cleaner.go
new file mode 100644
index 00000000..0ee746aa
--- /dev/null
+++ b/backend/internal/pkg/antigravity/schema_cleaner.go
@@ -0,0 +1,519 @@
+package antigravity
+
+import (
+ "fmt"
+ "strings"
+)
+
+// CleanJSONSchema 清理 JSON Schema,移除 Antigravity/Gemini 不支持的字段
+// 参考 Antigravity-Manager/src-tauri/src/proxy/common/json_schema.rs 实现
+// 确保 schema 符合 JSON Schema draft 2020-12 且适配 Gemini v1internal
+func CleanJSONSchema(schema map[string]any) map[string]any {
+ if schema == nil {
+ return nil
+ }
+ // 0. 预处理:展开 $ref (Schema Flattening)
+ // (Go map 是引用的,直接修改 schema)
+ flattenRefs(schema, extractDefs(schema))
+
+ // 递归清理
+ cleaned := cleanJSONSchemaRecursive(schema)
+ result, ok := cleaned.(map[string]any)
+ if !ok {
+ return nil
+ }
+
+ return result
+}
+
+// extractDefs 提取并移除定义的 helper
+func extractDefs(schema map[string]any) map[string]any {
+ defs := make(map[string]any)
+ if d, ok := schema["$defs"].(map[string]any); ok {
+ for k, v := range d {
+ defs[k] = v
+ }
+ delete(schema, "$defs")
+ }
+ if d, ok := schema["definitions"].(map[string]any); ok {
+ for k, v := range d {
+ defs[k] = v
+ }
+ delete(schema, "definitions")
+ }
+ return defs
+}
+
+// flattenRefs 递归展开 $ref
+func flattenRefs(schema map[string]any, defs map[string]any) {
+ if len(defs) == 0 {
+ return // 无需展开
+ }
+
+ // 检查并替换 $ref
+ if ref, ok := schema["$ref"].(string); ok {
+ delete(schema, "$ref")
+ // 解析引用名 (例如 #/$defs/MyType -> MyType)
+ parts := strings.Split(ref, "/")
+ refName := parts[len(parts)-1]
+
+ if defSchema, exists := defs[refName]; exists {
+ if defMap, ok := defSchema.(map[string]any); ok {
+ // 合并定义内容 (不覆盖现有 key)
+ for k, v := range defMap {
+ if _, has := schema[k]; !has {
+ schema[k] = deepCopy(v) // 需深拷贝避免共享引用
+ }
+ }
+ // 递归处理刚刚合并进来的内容
+ flattenRefs(schema, defs)
+ }
+ }
+ }
+
+ // 遍历子节点
+ for _, v := range schema {
+ if subMap, ok := v.(map[string]any); ok {
+ flattenRefs(subMap, defs)
+ } else if subArr, ok := v.([]any); ok {
+ for _, item := range subArr {
+ if itemMap, ok := item.(map[string]any); ok {
+ flattenRefs(itemMap, defs)
+ }
+ }
+ }
+ }
+}
+
+// deepCopy 深拷贝 (简单实现,仅针对 JSON 类型)
+func deepCopy(src any) any {
+ if src == nil {
+ return nil
+ }
+ switch v := src.(type) {
+ case map[string]any:
+ dst := make(map[string]any)
+ for k, val := range v {
+ dst[k] = deepCopy(val)
+ }
+ return dst
+ case []any:
+ dst := make([]any, len(v))
+ for i, val := range v {
+ dst[i] = deepCopy(val)
+ }
+ return dst
+ default:
+ return src
+ }
+}
+
+// cleanJSONSchemaRecursive 递归核心清理逻辑
+// 返回处理后的值 (通常是 input map,但可能修改内部结构)
+func cleanJSONSchemaRecursive(value any) any {
+ schemaMap, ok := value.(map[string]any)
+ if !ok {
+ return value
+ }
+
+ // 0. [NEW] 合并 allOf
+ mergeAllOf(schemaMap)
+
+ // 1. [CRITICAL] 深度递归处理子项
+ if props, ok := schemaMap["properties"].(map[string]any); ok {
+ for _, v := range props {
+ cleanJSONSchemaRecursive(v)
+ }
+ // Go 中不需要像 Rust 那样显式处理 nullable_keys remove required,
+ // 因为我们在子项处理中会正确设置 type 和 description
+ } else if items, ok := schemaMap["items"]; ok {
+ // [FIX] Gemini 期望 "items" 是单个 Schema 对象(列表验证),而不是数组(元组验证)。
+ if itemsArr, ok := items.([]any); ok {
+ // 策略:将元组 [A, B] 视为 A、B 中的最佳匹配项。
+ best := extractBestSchemaFromUnion(itemsArr)
+ if best == nil {
+ // 回退到通用字符串
+ best = map[string]any{"type": "string"}
+ }
+ // 用处理后的对象替换原有数组
+ cleanedBest := cleanJSONSchemaRecursive(best)
+ schemaMap["items"] = cleanedBest
+ } else {
+ cleanJSONSchemaRecursive(items)
+ }
+ } else {
+ // 遍历所有值递归
+ for _, v := range schemaMap {
+ if _, isMap := v.(map[string]any); isMap {
+ cleanJSONSchemaRecursive(v)
+ } else if arr, isArr := v.([]any); isArr {
+ for _, item := range arr {
+ cleanJSONSchemaRecursive(item)
+ }
+ }
+ }
+ }
+
+ // 2. [FIX] 处理 anyOf/oneOf 联合类型: 合并属性而非直接删除
+ var unionArray []any
+ typeStr, _ := schemaMap["type"].(string)
+ if typeStr == "" || typeStr == "object" {
+ if anyOf, ok := schemaMap["anyOf"].([]any); ok {
+ unionArray = anyOf
+ } else if oneOf, ok := schemaMap["oneOf"].([]any); ok {
+ unionArray = oneOf
+ }
+ }
+
+ if len(unionArray) > 0 {
+ if bestBranch := extractBestSchemaFromUnion(unionArray); bestBranch != nil {
+ if bestMap, ok := bestBranch.(map[string]any); ok {
+ // 合并分支内容
+ for k, v := range bestMap {
+ if k == "properties" {
+ targetProps, _ := schemaMap["properties"].(map[string]any)
+ if targetProps == nil {
+ targetProps = make(map[string]any)
+ schemaMap["properties"] = targetProps
+ }
+ if sourceProps, ok := v.(map[string]any); ok {
+ for pk, pv := range sourceProps {
+ if _, exists := targetProps[pk]; !exists {
+ targetProps[pk] = deepCopy(pv)
+ }
+ }
+ }
+ } else if k == "required" {
+ targetReq, _ := schemaMap["required"].([]any)
+ if sourceReq, ok := v.([]any); ok {
+ for _, rv := range sourceReq {
+ // 简单的去重添加
+ exists := false
+ for _, tr := range targetReq {
+ if tr == rv {
+ exists = true
+ break
+ }
+ }
+ if !exists {
+ targetReq = append(targetReq, rv)
+ }
+ }
+ schemaMap["required"] = targetReq
+ }
+ } else if _, exists := schemaMap[k]; !exists {
+ schemaMap[k] = deepCopy(v)
+ }
+ }
+ }
+ }
+ }
+
+ // 3. [SAFETY] 检查当前对象是否为 JSON Schema 节点
+ looksLikeSchema := hasKey(schemaMap, "type") ||
+ hasKey(schemaMap, "properties") ||
+ hasKey(schemaMap, "items") ||
+ hasKey(schemaMap, "enum") ||
+ hasKey(schemaMap, "anyOf") ||
+ hasKey(schemaMap, "oneOf") ||
+ hasKey(schemaMap, "allOf")
+
+ if looksLikeSchema {
+ // 4. [ROBUST] 约束迁移
+ migrateConstraints(schemaMap)
+
+ // 5. [CRITICAL] 白名单过滤
+ allowedFields := map[string]bool{
+ "type": true,
+ "description": true,
+ "properties": true,
+ "required": true,
+ "items": true,
+ "enum": true,
+ "title": true,
+ }
+ for k := range schemaMap {
+ if !allowedFields[k] {
+ delete(schemaMap, k)
+ }
+ }
+
+ // 6. [SAFETY] 处理空 Object
+ if t, _ := schemaMap["type"].(string); t == "object" {
+ hasProps := false
+ if props, ok := schemaMap["properties"].(map[string]any); ok && len(props) > 0 {
+ hasProps = true
+ }
+ if !hasProps {
+ schemaMap["properties"] = map[string]any{
+ "reason": map[string]any{
+ "type": "string",
+ "description": "Reason for calling this tool",
+ },
+ }
+ schemaMap["required"] = []any{"reason"}
+ }
+ }
+
+ // 7. [SAFETY] Required 字段对齐
+ if props, ok := schemaMap["properties"].(map[string]any); ok {
+ if req, ok := schemaMap["required"].([]any); ok {
+ var validReq []any
+ for _, r := range req {
+ if rStr, ok := r.(string); ok {
+ if _, exists := props[rStr]; exists {
+ validReq = append(validReq, r)
+ }
+ }
+ }
+ if len(validReq) > 0 {
+ schemaMap["required"] = validReq
+ } else {
+ delete(schemaMap, "required")
+ }
+ }
+ }
+
+ // 8. 处理 type 字段 (Lowercase + Nullable 提取)
+ isEffectivelyNullable := false
+ if typeVal, exists := schemaMap["type"]; exists {
+ var selectedType string
+ switch v := typeVal.(type) {
+ case string:
+ lower := strings.ToLower(v)
+ if lower == "null" {
+ isEffectivelyNullable = true
+ selectedType = "string" // fallback
+ } else {
+ selectedType = lower
+ }
+ case []any:
+ // ["string", "null"]
+ for _, t := range v {
+ if ts, ok := t.(string); ok {
+ lower := strings.ToLower(ts)
+ if lower == "null" {
+ isEffectivelyNullable = true
+ } else if selectedType == "" {
+ selectedType = lower
+ }
+ }
+ }
+ if selectedType == "" {
+ selectedType = "string"
+ }
+ }
+ schemaMap["type"] = selectedType
+ } else {
+ // 默认 object 如果有 properties (虽然上面白名单过滤可能删了 type 如果它不在... 但 type 必在 allowlist)
+ // 如果没有 type,但有 properties,补一个
+ if hasKey(schemaMap, "properties") {
+ schemaMap["type"] = "object"
+ } else {
+ // 默认为 string ? or object? Gemini 通常需要明确 type
+ schemaMap["type"] = "object"
+ }
+ }
+
+ if isEffectivelyNullable {
+ desc, _ := schemaMap["description"].(string)
+ if !strings.Contains(desc, "nullable") {
+ if desc != "" {
+ desc += " "
+ }
+ desc += "(nullable)"
+ schemaMap["description"] = desc
+ }
+ }
+
+ // 9. Enum 值强制转字符串
+ if enumVals, ok := schemaMap["enum"].([]any); ok {
+ hasNonString := false
+ for i, val := range enumVals {
+ if _, isStr := val.(string); !isStr {
+ hasNonString = true
+ if val == nil {
+ enumVals[i] = "null"
+ } else {
+ enumVals[i] = fmt.Sprintf("%v", val)
+ }
+ }
+ }
+ // If we mandated string values, we must ensure type is string
+ if hasNonString {
+ schemaMap["type"] = "string"
+ }
+ }
+ }
+
+ return schemaMap
+}
+
+func hasKey(m map[string]any, k string) bool {
+ _, ok := m[k]
+ return ok
+}
+
+func migrateConstraints(m map[string]any) {
+ constraints := []struct {
+ key string
+ label string
+ }{
+ {"minLength", "minLen"},
+ {"maxLength", "maxLen"},
+ {"pattern", "pattern"},
+ {"minimum", "min"},
+ {"maximum", "max"},
+ {"multipleOf", "multipleOf"},
+ {"exclusiveMinimum", "exclMin"},
+ {"exclusiveMaximum", "exclMax"},
+ {"minItems", "minItems"},
+ {"maxItems", "maxItems"},
+ {"propertyNames", "propertyNames"},
+ {"format", "format"},
+ }
+
+ var hints []string
+ for _, c := range constraints {
+ if val, ok := m[c.key]; ok && val != nil {
+ hints = append(hints, fmt.Sprintf("%s: %v", c.label, val))
+ }
+ }
+
+ if len(hints) > 0 {
+ suffix := fmt.Sprintf(" [Constraint: %s]", strings.Join(hints, ", "))
+ desc, _ := m["description"].(string)
+ if !strings.Contains(desc, suffix) {
+ m["description"] = desc + suffix
+ }
+ }
+}
+
+// mergeAllOf 合并 allOf
+func mergeAllOf(m map[string]any) {
+ allOf, ok := m["allOf"].([]any)
+ if !ok {
+ return
+ }
+ delete(m, "allOf")
+
+ mergedProps := make(map[string]any)
+ mergedReq := make(map[string]bool)
+ otherFields := make(map[string]any)
+
+ for _, sub := range allOf {
+ if subMap, ok := sub.(map[string]any); ok {
+ // Props
+ if props, ok := subMap["properties"].(map[string]any); ok {
+ for k, v := range props {
+ mergedProps[k] = v
+ }
+ }
+ // Required
+ if reqs, ok := subMap["required"].([]any); ok {
+ for _, r := range reqs {
+ if s, ok := r.(string); ok {
+ mergedReq[s] = true
+ }
+ }
+ }
+ // Others
+ for k, v := range subMap {
+ if k != "properties" && k != "required" && k != "allOf" {
+ if _, exists := otherFields[k]; !exists {
+ otherFields[k] = v
+ }
+ }
+ }
+ }
+ }
+
+ // Apply
+ for k, v := range otherFields {
+ if _, exists := m[k]; !exists {
+ m[k] = v
+ }
+ }
+ if len(mergedProps) > 0 {
+ existProps, _ := m["properties"].(map[string]any)
+ if existProps == nil {
+ existProps = make(map[string]any)
+ m["properties"] = existProps
+ }
+ for k, v := range mergedProps {
+ if _, exists := existProps[k]; !exists {
+ existProps[k] = v
+ }
+ }
+ }
+ if len(mergedReq) > 0 {
+ existReq, _ := m["required"].([]any)
+ var validReqs []any
+ for _, r := range existReq {
+ if s, ok := r.(string); ok {
+ validReqs = append(validReqs, s)
+ delete(mergedReq, s) // already exists
+ }
+ }
+ // append new
+ for r := range mergedReq {
+ validReqs = append(validReqs, r)
+ }
+ m["required"] = validReqs
+ }
+}
+
+// extractBestSchemaFromUnion 从 anyOf/oneOf 中选取最佳分支
+func extractBestSchemaFromUnion(unionArray []any) any {
+ var bestOption any
+ bestScore := -1
+
+ for _, item := range unionArray {
+ score := scoreSchemaOption(item)
+ if score > bestScore {
+ bestScore = score
+ bestOption = item
+ }
+ }
+ return bestOption
+}
+
+func scoreSchemaOption(val any) int {
+ m, ok := val.(map[string]any)
+ if !ok {
+ return 0
+ }
+ typeStr, _ := m["type"].(string)
+
+ if hasKey(m, "properties") || typeStr == "object" {
+ return 3
+ }
+ if hasKey(m, "items") || typeStr == "array" {
+ return 2
+ }
+ if typeStr != "" && typeStr != "null" {
+ return 1
+ }
+ return 0
+}
+
+// DeepCleanUndefined 深度清理值为 "[undefined]" 的字段
+func DeepCleanUndefined(value any) {
+ if value == nil {
+ return
+ }
+ switch v := value.(type) {
+ case map[string]any:
+ for k, val := range v {
+ if s, ok := val.(string); ok && s == "[undefined]" {
+ delete(v, k)
+ continue
+ }
+ DeepCleanUndefined(val)
+ }
+ case []any:
+ for _, val := range v {
+ DeepCleanUndefined(val)
+ }
+ }
+}
diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go
index da0c6f97..b384658a 100644
--- a/backend/internal/pkg/antigravity/stream_transformer.go
+++ b/backend/internal/pkg/antigravity/stream_transformer.go
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "log"
"strings"
)
@@ -102,6 +103,14 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte {
// 检查是否结束
if len(geminiResp.Candidates) > 0 {
finishReason := geminiResp.Candidates[0].FinishReason
+ if finishReason == "MALFORMED_FUNCTION_CALL" {
+ log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in stream for model %s", p.originalModel)
+ if geminiResp.Candidates[0].Content != nil {
+ if b, err := json.Marshal(geminiResp.Candidates[0].Content); err == nil {
+ log.Printf("[Antigravity] Malformed content: %s", string(b))
+ }
+ }
+ }
if finishReason != "" {
_, _ = result.Write(p.emitFinish(finishReason))
}
diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go
index 0a607dfb..33caffd7 100644
--- a/backend/internal/pkg/oauth/oauth.go
+++ b/backend/internal/pkg/oauth/oauth.go
@@ -24,9 +24,9 @@ const (
RedirectURI = "https://platform.claude.com/oauth/code/callback"
// Scopes - Browser URL (includes org:create_api_key for user authorization)
- ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code"
+ ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Internal API call (org:create_api_key not supported in API)
- ScopeAPI = "user:profile user:inference user:sessions:claude_code"
+ ScopeAPI = "user:profile user:inference user:sessions:claude_code user:mcp_servers"
// Scopes - Setup token (inference only)
ScopeInference = "user:inference"
@@ -215,5 +215,6 @@ type OrgInfo struct {
// AccountInfo represents account info from OAuth response
type AccountInfo struct {
- UUID string `json:"uuid"`
+ UUID string `json:"uuid"`
+ EmailAddress string `json:"email_address"`
}
diff --git a/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
new file mode 100644
index 00000000..eea74fcc
--- /dev/null
+++ b/backend/internal/pkg/tlsfingerprint/dialer_integration_test.go
@@ -0,0 +1,278 @@
+//go:build integration
+
+// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
+//
+// Integration tests for verifying TLS fingerprint correctness.
+// These tests make actual network requests to external services and should be run manually.
+//
+// Run with: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
+package tlsfingerprint
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+)
+
+// skipIfExternalServiceUnavailable checks if the external service is available.
+// If not, it skips the test instead of failing.
+func skipIfExternalServiceUnavailable(t *testing.T, err error) {
+ t.Helper()
+ if err != nil {
+ // Check for common network/TLS errors that indicate external service issues
+ errStr := err.Error()
+ if strings.Contains(errStr, "certificate has expired") ||
+ strings.Contains(errStr, "certificate is not yet valid") ||
+ strings.Contains(errStr, "connection refused") ||
+ strings.Contains(errStr, "no such host") ||
+ strings.Contains(errStr, "network is unreachable") ||
+ strings.Contains(errStr, "timeout") {
+ t.Skipf("skipping test: external service unavailable: %v", err)
+ }
+ t.Fatalf("failed to get fingerprint: %v", err)
+ }
+}
+
+// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
+// This test uses tls.peet.ws to verify the fingerprint.
+// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
+// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
+func TestJA3Fingerprint(t *testing.T) {
+ // Skip if network is unavailable or if running in short mode
+ if testing.Short() {
+ t.Skip("skipping integration test in short mode")
+ }
+
+ profile := &Profile{
+ Name: "Claude CLI Test",
+ EnableGREASE: false,
+ }
+ dialer := NewDialer(profile, nil)
+
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ // Use tls.peet.ws fingerprint detection API
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
+ }
+ req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
+
+ resp, err := client.Do(req)
+ skipIfExternalServiceUnavailable(t, err)
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ }
+
+ var fpResp FingerprintResponse
+ if err := json.Unmarshal(body, &fpResp); err != nil {
+ t.Logf("Response body: %s", string(body))
+ t.Fatalf("failed to parse fingerprint response: %v", err)
+ }
+
+ // Log all fingerprint information
+ t.Logf("JA3: %s", fpResp.TLS.JA3)
+ t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
+ t.Logf("JA4: %s", fpResp.TLS.JA4)
+ t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
+ t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
+
+ // Verify JA3 hash matches expected value
+ expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
+ if fpResp.TLS.JA3Hash == expectedJA3Hash {
+ t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
+ } else {
+ t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
+ }
+
+ // Verify JA4 fingerprint
+ // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
+ // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
+ // The suffix _a33745022dd6_1f22a2ca17c4 should match
+ expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
+ if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
+ t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
+ } else {
+ t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
+ }
+
+ // Verify JA4 prefix (t13d5911h1 or t13i5911h1)
+ // d = domain (SNI present), i = IP (no SNI)
+ // Since we connect to tls.peet.ws (domain), we expect 'd'
+ expectedJA4Prefix := "t13d5911h1"
+ if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
+ t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
+ } else {
+ // Also accept 'i' variant for IP connections
+ altPrefix := "t13i5911h1"
+ if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
+ t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
+ } else {
+ t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
+ }
+ }
+
+ // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
+ if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
+ t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
+ } else {
+ t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
+ }
+
+ // Verify extension list (should be 11 extensions including SNI)
+ // Expected: 0-11-10-35-16-22-23-13-43-45-51
+ expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
+ if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
+ t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
+ } else {
+ t.Logf("Warning: JA3 extension list may differ")
+ }
+}
+
+// TestProfileExpectation defines expected fingerprint values for a profile.
+type TestProfileExpectation struct {
+ Profile *Profile
+ ExpectedJA3 string // Expected JA3 hash (empty = don't check)
+ ExpectedJA4 string // Expected full JA4 (empty = don't check)
+ JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
+}
+
+// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
+// Run with: go test -v -tags=integration -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
+func TestAllProfiles(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping integration test in short mode")
+ }
+
+ // Define all profiles to test with their expected fingerprints
+ // These profiles are from config.yaml gateway.tls_fingerprint.profiles
+ profiles := []TestProfileExpectation{
+ {
+ // Linux x64 Node.js v22.17.1
+ // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
+ // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
+ Profile: &Profile{
+ Name: "linux_x64_node_v22171",
+ EnableGREASE: false,
+ CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
+ Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
+ PointFormats: []uint8{0, 1, 2},
+ },
+ JA4CipherHash: "a33745022dd6", // stable part
+ },
+ {
+ // MacOS arm64 Node.js v22.18.0
+ // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
+ // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
+ Profile: &Profile{
+ Name: "macos_arm64_node_v22180",
+ EnableGREASE: false,
+ CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
+ Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
+ PointFormats: []uint8{0, 1, 2},
+ },
+ JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
+ },
+ }
+
+ for _, tc := range profiles {
+ tc := tc // capture range variable
+ t.Run(tc.Profile.Name, func(t *testing.T) {
+ fp := fetchFingerprint(t, tc.Profile)
+ if fp == nil {
+ return // fetchFingerprint already called t.Fatal
+ }
+
+ t.Logf("Profile: %s", tc.Profile.Name)
+ t.Logf(" JA3: %s", fp.JA3)
+ t.Logf(" JA3 Hash: %s", fp.JA3Hash)
+ t.Logf(" JA4: %s", fp.JA4)
+ t.Logf(" PeetPrint: %s", fp.PeetPrint)
+ t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
+
+ // Verify expectations
+ if tc.ExpectedJA3 != "" {
+ if fp.JA3Hash == tc.ExpectedJA3 {
+ t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
+ } else {
+ t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
+ }
+ }
+
+ if tc.ExpectedJA4 != "" {
+ if fp.JA4 == tc.ExpectedJA4 {
+ t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
+ } else {
+ t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
+ }
+ }
+
+ // Check JA4 cipher hash (stable middle part)
+ // JA4 format: prefix_cipherHash_extHash
+ if tc.JA4CipherHash != "" {
+ if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
+ t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
+ } else {
+ t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
+ }
+ }
+ })
+ }
+}
+
+// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
+func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
+ t.Helper()
+
+ dialer := NewDialer(profile, nil)
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialTLSContext: dialer.DialTLSContext,
+ },
+ Timeout: 30 * time.Second,
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
+ if err != nil {
+ t.Fatalf("failed to create request: %v", err)
+ return nil
+ }
+ req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
+
+ resp, err := client.Do(req)
+ skipIfExternalServiceUnavailable(t, err)
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("failed to read response: %v", err)
+ return nil
+ }
+
+ var fpResp FingerprintResponse
+ if err := json.Unmarshal(body, &fpResp); err != nil {
+ t.Logf("Response body: %s", string(body))
+ t.Fatalf("failed to parse fingerprint response: %v", err)
+ return nil
+ }
+
+ return &fpResp.TLS
+}
diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go
index 845d51e5..dff7570f 100644
--- a/backend/internal/pkg/tlsfingerprint/dialer_test.go
+++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go
@@ -1,21 +1,16 @@
// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients.
//
-// Integration tests for verifying TLS fingerprint correctness.
-// These tests make actual network requests and should be run manually.
+// Unit tests for TLS fingerprint dialer.
+// Integration tests that require external network are in dialer_integration_test.go
+// and require the 'integration' build tag.
//
-// Run with: go test -v ./internal/pkg/tlsfingerprint/...
-// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/...
+// Run unit tests: go test -v ./internal/pkg/tlsfingerprint/...
+// Run integration tests: go test -v -tags=integration ./internal/pkg/tlsfingerprint/...
package tlsfingerprint
import (
- "context"
- "encoding/json"
- "io"
- "net/http"
"net/url"
- "strings"
"testing"
- "time"
)
// FingerprintResponse represents the response from tls.peet.ws/api/all.
@@ -36,148 +31,6 @@ type TLSInfo struct {
SessionID string `json:"session_id"`
}
-// TestDialerBasicConnection tests that the dialer can establish TLS connections.
-func TestDialerBasicConnection(t *testing.T) {
- if testing.Short() {
- t.Skip("skipping network test in short mode")
- }
-
- // Create a dialer with default profile
- profile := &Profile{
- Name: "Test Profile",
- EnableGREASE: false,
- }
- dialer := NewDialer(profile, nil)
-
- // Create HTTP client with custom TLS dialer
- client := &http.Client{
- Transport: &http.Transport{
- DialTLSContext: dialer.DialTLSContext,
- },
- Timeout: 30 * time.Second,
- }
-
- // Make a request to a known HTTPS endpoint
- resp, err := client.Get("https://www.google.com")
- if err != nil {
- t.Fatalf("failed to connect: %v", err)
- }
- defer func() { _ = resp.Body.Close() }()
-
- if resp.StatusCode != http.StatusOK {
- t.Errorf("expected status 200, got %d", resp.StatusCode)
- }
-}
-
-// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value.
-// This test uses tls.peet.ws to verify the fingerprint.
-// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x)
-// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... (i=IP)
-func TestJA3Fingerprint(t *testing.T) {
- // Skip if network is unavailable or if running in short mode
- if testing.Short() {
- t.Skip("skipping integration test in short mode")
- }
-
- profile := &Profile{
- Name: "Claude CLI Test",
- EnableGREASE: false,
- }
- dialer := NewDialer(profile, nil)
-
- client := &http.Client{
- Transport: &http.Transport{
- DialTLSContext: dialer.DialTLSContext,
- },
- Timeout: 30 * time.Second,
- }
-
- // Use tls.peet.ws fingerprint detection API
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
-
- req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
- if err != nil {
- t.Fatalf("failed to create request: %v", err)
- }
- req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
-
- resp, err := client.Do(req)
- if err != nil {
- t.Fatalf("failed to get fingerprint: %v", err)
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("failed to read response: %v", err)
- }
-
- var fpResp FingerprintResponse
- if err := json.Unmarshal(body, &fpResp); err != nil {
- t.Logf("Response body: %s", string(body))
- t.Fatalf("failed to parse fingerprint response: %v", err)
- }
-
- // Log all fingerprint information
- t.Logf("JA3: %s", fpResp.TLS.JA3)
- t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash)
- t.Logf("JA4: %s", fpResp.TLS.JA4)
- t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint)
- t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash)
-
- // Verify JA3 hash matches expected value
- expectedJA3Hash := "1a28e69016765d92e3b381168d68922c"
- if fpResp.TLS.JA3Hash == expectedJA3Hash {
- t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash)
- } else {
- t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash)
- }
-
- // Verify JA4 fingerprint
- // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash]
- // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP)
- // The suffix _a33745022dd6_1f22a2ca17c4 should match
- expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4"
- if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) {
- t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix)
- } else {
- t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix)
- }
-
- // Verify JA4 prefix (t13d5911h1 or t13i5911h1)
- // d = domain (SNI present), i = IP (no SNI)
- // Since we connect to tls.peet.ws (domain), we expect 'd'
- expectedJA4Prefix := "t13d5911h1"
- if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) {
- t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix)
- } else {
- // Also accept 'i' variant for IP connections
- altPrefix := "t13i5911h1"
- if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) {
- t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix)
- } else {
- t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix)
- }
- }
-
- // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning)
- if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") {
- t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites")
- } else {
- t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites")
- }
-
- // Verify extension list (should be 11 extensions including SNI)
- // Expected: 0-11-10-35-16-22-23-13-43-45-51
- expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51"
- if strings.Contains(fpResp.TLS.JA3, expectedExtensions) {
- t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions)
- } else {
- t.Logf("Warning: JA3 extension list may differ")
- }
-}
-
// TestDialerWithProfile tests that different profiles produce different fingerprints.
func TestDialerWithProfile(t *testing.T) {
// Create two dialers with different profiles
@@ -305,139 +158,3 @@ func mustParseURL(rawURL string) *url.URL {
}
return u
}
-
-// TestProfileExpectation defines expected fingerprint values for a profile.
-type TestProfileExpectation struct {
- Profile *Profile
- ExpectedJA3 string // Expected JA3 hash (empty = don't check)
- ExpectedJA4 string // Expected full JA4 (empty = don't check)
- JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check)
-}
-
-// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws.
-// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/...
-func TestAllProfiles(t *testing.T) {
- if testing.Short() {
- t.Skip("skipping integration test in short mode")
- }
-
- // Define all profiles to test with their expected fingerprints
- // These profiles are from config.yaml gateway.tls_fingerprint.profiles
- profiles := []TestProfileExpectation{
- {
- // Linux x64 Node.js v22.17.1
- // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c
- // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4
- Profile: &Profile{
- Name: "linux_x64_node_v22171",
- EnableGREASE: false,
- CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
- Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
- PointFormats: []uint8{0, 1, 2},
- },
- JA4CipherHash: "a33745022dd6", // stable part
- },
- {
- // MacOS arm64 Node.js v22.18.0
- // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea
- // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406
- Profile: &Profile{
- Name: "macos_arm64_node_v22180",
- EnableGREASE: false,
- CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255},
- Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260},
- PointFormats: []uint8{0, 1, 2},
- },
- JA4CipherHash: "a33745022dd6", // stable part (same cipher suites)
- },
- }
-
- for _, tc := range profiles {
- tc := tc // capture range variable
- t.Run(tc.Profile.Name, func(t *testing.T) {
- fp := fetchFingerprint(t, tc.Profile)
- if fp == nil {
- return // fetchFingerprint already called t.Fatal
- }
-
- t.Logf("Profile: %s", tc.Profile.Name)
- t.Logf(" JA3: %s", fp.JA3)
- t.Logf(" JA3 Hash: %s", fp.JA3Hash)
- t.Logf(" JA4: %s", fp.JA4)
- t.Logf(" PeetPrint: %s", fp.PeetPrint)
- t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash)
-
- // Verify expectations
- if tc.ExpectedJA3 != "" {
- if fp.JA3Hash == tc.ExpectedJA3 {
- t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3)
- } else {
- t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3)
- }
- }
-
- if tc.ExpectedJA4 != "" {
- if fp.JA4 == tc.ExpectedJA4 {
- t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4)
- } else {
- t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4)
- }
- }
-
- // Check JA4 cipher hash (stable middle part)
- // JA4 format: prefix_cipherHash_extHash
- if tc.JA4CipherHash != "" {
- if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") {
- t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash)
- } else {
- t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash)
- }
- }
- })
- }
-}
-
-// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info.
-func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo {
- t.Helper()
-
- dialer := NewDialer(profile, nil)
- client := &http.Client{
- Transport: &http.Transport{
- DialTLSContext: dialer.DialTLSContext,
- },
- Timeout: 30 * time.Second,
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
- defer cancel()
-
- req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil)
- if err != nil {
- t.Fatalf("failed to create request: %v", err)
- return nil
- }
- req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0")
-
- resp, err := client.Do(req)
- if err != nil {
- t.Fatalf("failed to get fingerprint: %v", err)
- return nil
- }
- defer func() { _ = resp.Body.Close() }()
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- t.Fatalf("failed to read response: %v", err)
- return nil
- }
-
- var fpResp FingerprintResponse
- if err := json.Unmarshal(body, &fpResp); err != nil {
- t.Logf("Response body: %s", string(body))
- t.Fatalf("failed to parse fingerprint response: %v", err)
- return nil
- }
-
- return &fpResp.TLS
-}
diff --git a/backend/internal/repository/aes_encryptor.go b/backend/internal/repository/aes_encryptor.go
new file mode 100644
index 00000000..924e3698
--- /dev/null
+++ b/backend/internal/repository/aes_encryptor.go
@@ -0,0 +1,95 @@
+package repository
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/hex"
+ "fmt"
+ "io"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+// AESEncryptor implements SecretEncryptor using AES-256-GCM
+type AESEncryptor struct {
+ key []byte
+}
+
+// NewAESEncryptor creates a new AES encryptor
+func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
+ key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
+ if err != nil {
+ return nil, fmt.Errorf("invalid totp encryption key: %w", err)
+ }
+
+ if len(key) != 32 {
+ return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
+ }
+
+ return &AESEncryptor{key: key}, nil
+}
+
+// Encrypt encrypts plaintext using AES-256-GCM
+// Output format: base64(nonce + ciphertext + tag)
+func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
+ block, err := aes.NewCipher(e.key)
+ if err != nil {
+ return "", fmt.Errorf("create cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", fmt.Errorf("create gcm: %w", err)
+ }
+
+ // Generate a random nonce
+ nonce := make([]byte, gcm.NonceSize())
+ if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+ return "", fmt.Errorf("generate nonce: %w", err)
+ }
+
+ // Encrypt the plaintext
+ // Seal appends the ciphertext and tag to the nonce
+ ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
+
+ // Encode as base64
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+// Decrypt decrypts ciphertext using AES-256-GCM
+func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
+ // Decode from base64
+ data, err := base64.StdEncoding.DecodeString(ciphertext)
+ if err != nil {
+ return "", fmt.Errorf("decode base64: %w", err)
+ }
+
+ block, err := aes.NewCipher(e.key)
+ if err != nil {
+ return "", fmt.Errorf("create cipher: %w", err)
+ }
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return "", fmt.Errorf("create gcm: %w", err)
+ }
+
+ nonceSize := gcm.NonceSize()
+ if len(data) < nonceSize {
+ return "", fmt.Errorf("ciphertext too short")
+ }
+
+ // Extract nonce and ciphertext
+ nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
+
+ // Decrypt
+ plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
+ if err != nil {
+ return "", fmt.Errorf("decrypt: %w", err)
+ }
+
+ return string(plaintext), nil
+}
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index ab890844..1e5a62df 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -387,17 +387,20 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
- ID: u.ID,
- Email: u.Email,
- Username: u.Username,
- Notes: u.Notes,
- PasswordHash: u.PasswordHash,
- Role: u.Role,
- Balance: u.Balance,
- Concurrency: u.Concurrency,
- Status: u.Status,
- CreatedAt: u.CreatedAt,
- UpdatedAt: u.UpdatedAt,
+ ID: u.ID,
+ Email: u.Email,
+ Username: u.Username,
+ Notes: u.Notes,
+ PasswordHash: u.PasswordHash,
+ Role: u.Role,
+ Balance: u.Balance,
+ Concurrency: u.Concurrency,
+ Status: u.Status,
+ TotpSecretEncrypted: u.TotpSecretEncrypted,
+ TotpEnabled: u.TotpEnabled,
+ TotpEnabledAt: u.TotpEnabledAt,
+ CreatedAt: u.CreatedAt,
+ UpdatedAt: u.UpdatedAt,
}
}
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go
index 1f1db553..fc0d2918 100644
--- a/backend/internal/repository/claude_oauth_service.go
+++ b/backend/internal/repository/claude_oauth_service.go
@@ -35,7 +35,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
client := s.clientFactory(proxyURL)
var orgs []struct {
- UUID string `json:"uuid"`
+ UUID string `json:"uuid"`
+ Name string `json:"name"`
+ RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
}
targetURL := s.baseURL + "/api/organizations"
@@ -65,7 +67,23 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
return "", fmt.Errorf("no organizations found")
}
- log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
+ // 如果只有一个组织,直接使用
+ if len(orgs) == 1 {
+ log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
+ return orgs[0].UUID, nil
+ }
+
+ // 如果有多个组织,优先选择 raven_type 为 "team" 的组织
+ for _, org := range orgs {
+ if org.RavenType != nil && *org.RavenType == "team" {
+ log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
+ org.UUID, org.Name, *org.RavenType)
+ return org.UUID, nil
+ }
+ }
+
+ // 如果没有 team 类型的组织,使用第一个
+ log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
return orgs[0].UUID, nil
}
diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go
index e00e35dd..8f2b8eca 100644
--- a/backend/internal/repository/email_cache.go
+++ b/backend/internal/repository/email_cache.go
@@ -9,13 +9,27 @@ import (
"github.com/redis/go-redis/v9"
)
-const verifyCodeKeyPrefix = "verify_code:"
+const (
+ verifyCodeKeyPrefix = "verify_code:"
+ passwordResetKeyPrefix = "password_reset:"
+ passwordResetSentAtKeyPrefix = "password_reset_sent:"
+)
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
+// passwordResetKey generates the Redis key for password reset token.
+func passwordResetKey(email string) string {
+ return passwordResetKeyPrefix + email
+}
+
+// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
+func passwordResetSentAtKey(email string) string {
+ return passwordResetSentAtKeyPrefix + email
+}
+
type emailCache struct {
rdb *redis.Client
}
@@ -50,3 +64,45 @@ func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) e
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}
+
+// Password reset token methods
+
+func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
+ key := passwordResetKey(email)
+ val, err := c.rdb.Get(ctx, key).Result()
+ if err != nil {
+ return nil, err
+ }
+ var data service.PasswordResetTokenData
+ if err := json.Unmarshal([]byte(val), &data); err != nil {
+ return nil, err
+ }
+ return &data, nil
+}
+
+func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
+ key := passwordResetKey(email)
+ val, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, val, ttl).Err()
+}
+
+func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
+ key := passwordResetKey(email)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// Password reset email cooldown methods
+
+func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
+ key := passwordResetSentAtKey(email)
+ exists, err := c.rdb.Exists(ctx, key).Result()
+ return err == nil && exists > 0
+}
+
+func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
+ key := passwordResetSentAtKey(email)
+ return c.rdb.Set(ctx, key, "1", ttl).Err()
+}
diff --git a/backend/internal/repository/totp_cache.go b/backend/internal/repository/totp_cache.go
new file mode 100644
index 00000000..2f4a8ab2
--- /dev/null
+++ b/backend/internal/repository/totp_cache.go
@@ -0,0 +1,149 @@
+package repository
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+const (
+ totpSetupKeyPrefix = "totp:setup:"
+ totpLoginKeyPrefix = "totp:login:"
+ totpAttemptsKeyPrefix = "totp:attempts:"
+ totpAttemptsTTL = 15 * time.Minute
+)
+
+// TotpCache implements service.TotpCache using Redis
+type TotpCache struct {
+ rdb *redis.Client
+}
+
+// NewTotpCache creates a new TOTP cache
+func NewTotpCache(rdb *redis.Client) service.TotpCache {
+ return &TotpCache{rdb: rdb}
+}
+
+// GetSetupSession retrieves a TOTP setup session
+func (c *TotpCache) GetSetupSession(ctx context.Context, userID int64) (*service.TotpSetupSession, error) {
+ key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
+ data, err := c.rdb.Get(ctx, key).Bytes()
+ if err != nil {
+ if err == redis.Nil {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("get setup session: %w", err)
+ }
+
+ var session service.TotpSetupSession
+ if err := json.Unmarshal(data, &session); err != nil {
+ return nil, fmt.Errorf("unmarshal setup session: %w", err)
+ }
+
+ return &session, nil
+}
+
+// SetSetupSession stores a TOTP setup session
+func (c *TotpCache) SetSetupSession(ctx context.Context, userID int64, session *service.TotpSetupSession, ttl time.Duration) error {
+ key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
+ data, err := json.Marshal(session)
+ if err != nil {
+ return fmt.Errorf("marshal setup session: %w", err)
+ }
+
+ if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
+ return fmt.Errorf("set setup session: %w", err)
+ }
+
+ return nil
+}
+
+// DeleteSetupSession deletes a TOTP setup session
+func (c *TotpCache) DeleteSetupSession(ctx context.Context, userID int64) error {
+ key := fmt.Sprintf("%s%d", totpSetupKeyPrefix, userID)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// GetLoginSession retrieves a TOTP login session
+func (c *TotpCache) GetLoginSession(ctx context.Context, tempToken string) (*service.TotpLoginSession, error) {
+ key := totpLoginKeyPrefix + tempToken
+ data, err := c.rdb.Get(ctx, key).Bytes()
+ if err != nil {
+ if err == redis.Nil {
+ return nil, nil
+ }
+ return nil, fmt.Errorf("get login session: %w", err)
+ }
+
+ var session service.TotpLoginSession
+ if err := json.Unmarshal(data, &session); err != nil {
+ return nil, fmt.Errorf("unmarshal login session: %w", err)
+ }
+
+ return &session, nil
+}
+
+// SetLoginSession stores a TOTP login session
+func (c *TotpCache) SetLoginSession(ctx context.Context, tempToken string, session *service.TotpLoginSession, ttl time.Duration) error {
+ key := totpLoginKeyPrefix + tempToken
+ data, err := json.Marshal(session)
+ if err != nil {
+ return fmt.Errorf("marshal login session: %w", err)
+ }
+
+ if err := c.rdb.Set(ctx, key, data, ttl).Err(); err != nil {
+ return fmt.Errorf("set login session: %w", err)
+ }
+
+ return nil
+}
+
+// DeleteLoginSession deletes a TOTP login session
+func (c *TotpCache) DeleteLoginSession(ctx context.Context, tempToken string) error {
+ key := totpLoginKeyPrefix + tempToken
+ return c.rdb.Del(ctx, key).Err()
+}
+
+// IncrementVerifyAttempts increments the verify attempt counter
+func (c *TotpCache) IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error) {
+ key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
+
+ // Use pipeline for atomic increment and set TTL
+ pipe := c.rdb.Pipeline()
+ incrCmd := pipe.Incr(ctx, key)
+ pipe.Expire(ctx, key, totpAttemptsTTL)
+
+ if _, err := pipe.Exec(ctx); err != nil {
+ return 0, fmt.Errorf("increment verify attempts: %w", err)
+ }
+
+ count, err := incrCmd.Result()
+ if err != nil {
+ return 0, fmt.Errorf("get increment result: %w", err)
+ }
+
+ return int(count), nil
+}
+
+// GetVerifyAttempts gets the current verify attempt count
+func (c *TotpCache) GetVerifyAttempts(ctx context.Context, userID int64) (int, error) {
+ key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
+ count, err := c.rdb.Get(ctx, key).Int()
+ if err != nil {
+ if err == redis.Nil {
+ return 0, nil
+ }
+ return 0, fmt.Errorf("get verify attempts: %w", err)
+ }
+ return count, nil
+}
+
+// ClearVerifyAttempts clears the verify attempt counter
+func (c *TotpCache) ClearVerifyAttempts(ctx context.Context, userID int64) error {
+ key := fmt.Sprintf("%s%d", totpAttemptsKeyPrefix, userID)
+ return c.rdb.Del(ctx, key).Err()
+}
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 006a5464..fe5b645c 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -7,6 +7,7 @@ import (
"fmt"
"sort"
"strings"
+ "time"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
@@ -466,3 +467,46 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}
+
+// UpdateTotpSecret 更新用户的 TOTP 加密密钥
+func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ client := clientFromContext(ctx, r.client)
+ update := client.User.UpdateOneID(userID)
+ if encryptedSecret == nil {
+ update = update.ClearTotpSecretEncrypted()
+ } else {
+ update = update.SetTotpSecretEncrypted(*encryptedSecret)
+ }
+ _, err := update.Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ return nil
+}
+
+// EnableTotp 启用用户的 TOTP 双因素认证
+func (r *userRepository) EnableTotp(ctx context.Context, userID int64) error {
+ client := clientFromContext(ctx, r.client)
+ _, err := client.User.UpdateOneID(userID).
+ SetTotpEnabled(true).
+ SetTotpEnabledAt(time.Now()).
+ Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ return nil
+}
+
+// DisableTotp 禁用用户的 TOTP 双因素认证
+func (r *userRepository) DisableTotp(ctx context.Context, userID int64) error {
+ client := clientFromContext(ctx, r.client)
+ _, err := client.User.UpdateOneID(userID).
+ SetTotpEnabled(false).
+ ClearTotpEnabledAt().
+ ClearTotpSecretEncrypted().
+ Save(ctx)
+ if err != nil {
+ return translatePersistenceError(err, service.ErrUserNotFound, nil)
+ }
+ return nil
+}
diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go
index cd3b9db6..5a649846 100644
--- a/backend/internal/repository/user_subscription_repo.go
+++ b/backend/internal/repository/user_subscription_repo.go
@@ -190,7 +190,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil
}
-func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
client := clientFromContext(ctx, r.client)
q := client.UserSubscription.Query()
if userID != nil {
@@ -199,7 +199,31 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
if groupID != nil {
q = q.Where(usersubscription.GroupIDEQ(*groupID))
}
- if status != "" {
+
+ // Status filtering with real-time expiration check
+ now := time.Now()
+ switch status {
+ case service.SubscriptionStatusActive:
+ // Active: status is active AND not yet expired
+ q = q.Where(
+ usersubscription.StatusEQ(service.SubscriptionStatusActive),
+ usersubscription.ExpiresAtGT(now),
+ )
+ case service.SubscriptionStatusExpired:
+ // Expired: status is expired OR (status is active but already expired)
+ q = q.Where(
+ usersubscription.Or(
+ usersubscription.StatusEQ(service.SubscriptionStatusExpired),
+ usersubscription.And(
+ usersubscription.StatusEQ(service.SubscriptionStatusActive),
+ usersubscription.ExpiresAtLTE(now),
+ ),
+ ),
+ )
+ case "":
+ // No filter
+ default:
+ // Other status (e.g., revoked)
q = q.Where(usersubscription.StatusEQ(status))
}
@@ -208,11 +232,28 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination
return nil, nil, err
}
+ // Apply sorting
+ q = q.WithUser().WithGroup().WithAssignedByUser()
+
+ // Determine sort field
+ var field string
+ switch sortBy {
+ case "expires_at":
+ field = usersubscription.FieldExpiresAt
+ case "status":
+ field = usersubscription.FieldStatus
+ default:
+ field = usersubscription.FieldCreatedAt
+ }
+
+ // Determine sort order (default: desc)
+ if sortOrder == "asc" && sortBy != "" {
+ q = q.Order(dbent.Asc(field))
+ } else {
+ q = q.Order(dbent.Desc(field))
+ }
+
subs, err := q.
- WithUser().
- WithGroup().
- WithAssignedByUser().
- Order(dbent.Desc(usersubscription.FieldCreatedAt)).
Offset(params.Offset()).
Limit(params.Limit()).
All(ctx)
diff --git a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go
index 2099e5d8..60a5a378 100644
--- a/backend/internal/repository/user_subscription_repo_integration_test.go
+++ b/backend/internal/repository/user_subscription_repo_integration_test.go
@@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
- subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
+ subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
@@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
- subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
+ subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
@@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
- subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
+ subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
@@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
- subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired)
+ subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 7a8d85f4..3e1c05fc 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -82,6 +82,10 @@ var ProviderSet = wire.NewSet(
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
+ NewTotpCache,
+
+ // Encryptors
+ NewAESEncryptor,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 230a3c60..1deab421 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -193,20 +193,20 @@ func TestAPIContracts(t *testing.T) {
// 普通用户订阅接口不应包含 assigned_* / notes 等管理员字段。
deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
{
- ID: 501,
- UserID: 1,
- GroupID: 10,
- StartsAt: deps.now,
- ExpiresAt: deps.now.Add(24 * time.Hour),
- Status: service.SubscriptionStatusActive,
+ ID: 501,
+ UserID: 1,
+ GroupID: 10,
+ StartsAt: deps.now,
+ ExpiresAt: time.Date(2099, 1, 2, 3, 4, 5, 0, time.UTC), // 使用未来日期避免 normalizeSubscriptionStatus 标记为过期
+ Status: service.SubscriptionStatusActive,
DailyUsageUSD: 1.23,
WeeklyUsageUSD: 2.34,
MonthlyUsageUSD: 3.45,
- AssignedBy: ptr(int64(999)),
- AssignedAt: deps.now,
- Notes: "admin-note",
- CreatedAt: deps.now,
- UpdatedAt: deps.now,
+ AssignedBy: ptr(int64(999)),
+ AssignedAt: deps.now,
+ Notes: "admin-note",
+ CreatedAt: deps.now,
+ UpdatedAt: deps.now,
},
})
},
@@ -222,7 +222,7 @@ func TestAPIContracts(t *testing.T) {
"user_id": 1,
"group_id": 10,
"starts_at": "2025-01-02T03:04:05Z",
- "expires_at": "2025-01-03T03:04:05Z",
+ "expires_at": "2099-01-02T03:04:05Z",
"status": "active",
"daily_window_start": null,
"weekly_window_start": null,
@@ -452,6 +452,9 @@ func TestAPIContracts(t *testing.T) {
"registration_enabled": true,
"email_verify_enabled": false,
"promo_code_enabled": true,
+ "password_reset_enabled": false,
+ "totp_enabled": false,
+ "totp_encryption_key_configured": false,
"smtp_host": "smtp.example.com",
"smtp_port": 587,
"smtp_username": "user",
@@ -595,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
- authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
+ authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
@@ -754,6 +757,18 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
+func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error {
+ return errors.New("not implemented")
+}
+
type stubApiKeyCache struct{}
func (stubApiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
@@ -1176,7 +1191,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI
func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
-func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go
index 84398093..920ff93f 100644
--- a/backend/internal/server/middleware/api_key_auth_test.go
+++ b/backend/internal/server/middleware/api_key_auth_test.go
@@ -367,7 +367,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in
return nil, nil, errors.New("not implemented")
}
-func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
+func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index aa691eba..33a88e82 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -26,11 +26,20 @@ func RegisterAuthRoutes(
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
+ auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.ValidatePromoCode)
+ // 忘记密码接口添加速率限制:每分钟最多 5 次(Redis 故障时 fail-close)
+ auth.POST("/forgot-password", rateLimiter.LimitWithOptions("forgot-password", 5, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }), h.Auth.ForgotPassword)
+ // 重置密码接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close)
+ auth.POST("/reset-password", rateLimiter.LimitWithOptions("reset-password", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }), h.Auth.ResetPassword)
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
}
diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go
index ad2166fe..83cf31c4 100644
--- a/backend/internal/server/routes/user.go
+++ b/backend/internal/server/routes/user.go
@@ -22,6 +22,17 @@ func RegisterUserRoutes(
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
+
+ // TOTP 双因素认证
+ totp := user.Group("/totp")
+ {
+ totp.GET("/status", h.Totp.GetStatus)
+ totp.GET("/verification-method", h.Totp.GetVerificationMethod)
+ totp.POST("/send-code", h.Totp.SendVerifyCode)
+ totp.POST("/setup", h.Totp.InitiateSetup)
+ totp.POST("/enable", h.Totp.Enable)
+ totp.POST("/disable", h.Totp.Disable)
+ }
}
// API Key管理
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index e710560f..7b958838 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -197,6 +197,35 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
+// GetCredentialAsInt64 解析凭证中的 int64 字段
+// 用于读取 _token_version 等内部字段
+func (a *Account) GetCredentialAsInt64(key string) int64 {
+ if a == nil || a.Credentials == nil {
+ return 0
+ }
+ val, ok := a.Credentials[key]
+ if !ok || val == nil {
+ return 0
+ }
+ switch v := val.(type) {
+ case int64:
+ return v
+ case float64:
+ return int64(v)
+ case int:
+ return int64(v)
+ case json.Number:
+ if i, err := v.Int64(); err == nil {
+ return i
+ }
+ case string:
+ if i, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
+ return i
+ }
+ }
+ return 0
+}
+
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index afa433af..6472ccbb 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
+func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
+ panic("unexpected UpdateTotpSecret call")
+}
+
+func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
+ panic("unexpected EnableTotp call")
+}
+
+func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
+ panic("unexpected DisableTotp call")
+}
+
type groupRepoStub struct {
affectedUserIDs []int64
deleteErr error
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index 043f338d..3b847bcb 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -1305,6 +1305,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, err
}
+ // 清理 Schema
+ if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil {
+ injectedBody = cleanedBody
+ log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name)
+ } else {
+ log.Printf("[Antigravity] Failed to clean schema: %v", err)
+ }
+
// 包装请求
wrappedBody, err := s.wrapV1InternalRequest(projectID, mappedModel, injectedBody)
if err != nil {
@@ -1705,6 +1713,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
if u := extractGeminiUsage(parsed); u != nil {
usage = u
}
+ // Check for MALFORMED_FUNCTION_CALL
+ if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
+ if cand, ok := candidates[0].(map[string]any); ok {
+ if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
+ log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream")
+ if content, ok := cand["content"]; ok {
+ if b, err := json.Marshal(content); err == nil {
+ log.Printf("[Antigravity] Malformed content: %s", string(b))
+ }
+ }
+ }
+ }
+ }
}
if firstTokenMs == nil {
@@ -1854,6 +1875,20 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
usage = u
}
+ // Check for MALFORMED_FUNCTION_CALL
+ if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
+ if cand, ok := candidates[0].(map[string]any); ok {
+ if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" {
+ log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect")
+ if content, ok := cand["content"]; ok {
+ if b, err := json.Marshal(content); err == nil {
+ log.Printf("[Antigravity] Malformed content: %s", string(b))
+ }
+ }
+ }
+ }
+ }
+
// 保留最后一个有 parts 的响应
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
@@ -1950,6 +1985,58 @@ func getOrCreateGeminiParts(response map[string]any) (result map[string]any, exi
return result, existingParts, setParts
}
+// mergeCollectedPartsToResponse 将收集的所有 parts 合并到 Gemini 响应中
+// 这个函数会合并所有类型的 parts:text、thinking、functionCall、inlineData 等
+// 保持原始顺序,只合并连续的普通 text parts
+func mergeCollectedPartsToResponse(response map[string]any, collectedParts []map[string]any) map[string]any {
+ if len(collectedParts) == 0 {
+ return response
+ }
+
+ result, _, setParts := getOrCreateGeminiParts(response)
+
+ // 合并策略:
+ // 1. 保持原始顺序
+ // 2. 连续的普通 text parts 合并为一个
+ // 3. thinking、functionCall、inlineData 等保持原样
+ var mergedParts []any
+ var textBuffer strings.Builder
+
+ flushTextBuffer := func() {
+ if textBuffer.Len() > 0 {
+ mergedParts = append(mergedParts, map[string]any{
+ "text": textBuffer.String(),
+ })
+ textBuffer.Reset()
+ }
+ }
+
+ for _, part := range collectedParts {
+ // 检查是否是普通 text part
+ if text, ok := part["text"].(string); ok {
+ // 检查是否有 thought 标记
+ if thought, _ := part["thought"].(bool); thought {
+ // thinking part,先刷新 text buffer,然后保留原样
+ flushTextBuffer()
+ mergedParts = append(mergedParts, part)
+ } else {
+ // 普通 text,累积到 buffer
+ _, _ = textBuffer.WriteString(text)
+ }
+ } else {
+ // 非 text part(functionCall、inlineData 等),先刷新 text buffer,然后保留原样
+ flushTextBuffer()
+ mergedParts = append(mergedParts, part)
+ }
+ }
+
+ // 刷新剩余的 text
+ flushTextBuffer()
+
+ setParts(mergedParts)
+ return result
+}
+
// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中
func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any {
if len(imageParts) == 0 {
@@ -2133,6 +2220,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
var firstTokenMs *int
var last map[string]any
var lastWithParts map[string]any
+ var collectedParts []map[string]any // 收集所有 parts(包括 text、thinking、functionCall、inlineData 等)
type scanEvent struct {
line string
@@ -2227,9 +2315,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont
last = parsed
- // 保留最后一个有 parts 的响应
+ // 保留最后一个有 parts 的响应,并收集所有 parts
if parts := extractGeminiParts(parsed); len(parts) > 0 {
lastWithParts = parsed
+
+ // 收集所有 parts(text、thinking、functionCall、inlineData 等)
+ collectedParts = append(collectedParts, parts...)
}
case <-intervalCh:
@@ -2252,6 +2343,11 @@ returnResponse:
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Empty response from upstream")
}
+ // 将收集的所有 parts 合并到最终响应中
+ if len(collectedParts) > 0 {
+ finalResponse = mergeCollectedPartsToResponse(finalResponse, collectedParts)
+ }
+
// 序列化为 JSON(Gemini 格式)
geminiBody, err := json.Marshal(finalResponse)
if err != nil {
@@ -2459,3 +2555,55 @@ func isImageGenerationModel(model string) bool {
modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
}
+
+// cleanGeminiRequest 清理 Gemini 请求体中的 Schema
+func cleanGeminiRequest(body []byte) ([]byte, error) {
+ var payload map[string]any
+ if err := json.Unmarshal(body, &payload); err != nil {
+ return nil, err
+ }
+
+ modified := false
+
+ // 1. 清理 Tools
+ if tools, ok := payload["tools"].([]any); ok && len(tools) > 0 {
+ for _, t := range tools {
+ toolMap, ok := t.(map[string]any)
+ if !ok {
+ continue
+ }
+
+ // function_declarations (snake_case) or functionDeclarations (camelCase)
+ var funcs []any
+ if f, ok := toolMap["functionDeclarations"].([]any); ok {
+ funcs = f
+ } else if f, ok := toolMap["function_declarations"].([]any); ok {
+ funcs = f
+ }
+
+ if len(funcs) == 0 {
+ continue
+ }
+
+ for _, f := range funcs {
+ funcMap, ok := f.(map[string]any)
+ if !ok {
+ continue
+ }
+
+ if params, ok := funcMap["parameters"].(map[string]any); ok {
+ antigravity.DeepCleanUndefined(params)
+ cleaned := antigravity.CleanJSONSchema(params)
+ funcMap["parameters"] = cleaned
+ modified = true
+ }
+ }
+ }
+ }
+
+ if !modified {
+ return body, nil
+ }
+
+ return json.Marshal(payload)
+}
diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go
index 52293cd5..fa8379ed 100644
--- a/backend/internal/service/antigravity_oauth_service.go
+++ b/backend/internal/service/antigravity_oauth_service.go
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
result.Email = userInfo.Email
}
- // 获取 project_id(部分账户类型可能没有)
- loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
- if err != nil {
- fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
- } else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
- result.ProjectID = loadResp.CloudAICompanionProject
+ // 获取 project_id(部分账户类型可能没有),失败时重试
+ projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
+ if loadErr != nil {
+ fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
+ result.ProjectIDMissing = true
+ } else {
+ result.ProjectID = projectID
}
return result, nil
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
tokenInfo.Email = existingEmail
}
- // 每次刷新都调用 LoadCodeAssist 获取 project_id
- client := antigravity.NewClient(proxyURL)
- loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
- if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
- // LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
- existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
+ // 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
+ existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
+ projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
+
+ if loadErr != nil {
+ // LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
- tokenInfo.ProjectIDMissing = true
+ // 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
+ // 如果之前有 project_id,本次只是临时故障,不应标记为错误
+ if existingProjectID == "" {
+ tokenInfo.ProjectIDMissing = true
+ }
} else {
- tokenInfo.ProjectID = loadResp.CloudAICompanionProject
+ tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
+// loadProjectIDWithRetry 带重试机制获取 project_id
+// 返回 project_id 和错误,失败时会重试指定次数
+func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
+ var lastErr error
+
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ if attempt > 0 {
+ // 指数退避:1s, 2s, 4s
+ backoff := time.Duration(1< 密码重置请求 您已请求重置密码。请点击下方按钮设置新密码: 此链接将在 30 分钟后失效。 如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。 如果按钮无法点击,请复制以下链接到浏览器中打开: %s%s
+