diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index fe3ad0cf..95586017 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -3,6 +3,7 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" @@ -35,6 +36,10 @@ type APIKey struct { GroupID *int64 `json:"group_id,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` + // Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"] + IPWhitelist []string `json:"ip_whitelist,omitempty"` + // Blocked IPs/CIDRs + IPBlacklist []string `json:"ip_blacklist,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. Edges APIKeyEdges `json:"edges"` @@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: + values[i] = new([]byte) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: @@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Status = value.String } + case apikey.FieldIPWhitelist: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil { + return fmt.Errorf("unmarshal field ip_whitelist: %w", err) + } + } + case apikey.FieldIPBlacklist: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil { + return fmt.Errorf("unmarshal field ip_blacklist: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -245,6 +268,12 @@ func (_m *APIKey) String() string { builder.WriteString(", ") builder.WriteString("status=") builder.WriteString(_m.Status) + builder.WriteString(", ") + builder.WriteString("ip_whitelist=") + builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist)) + builder.WriteString(", ") + builder.WriteString("ip_blacklist=") + builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 91f7d620..564cddb1 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -31,6 +31,10 @@ const ( FieldGroupID = "group_id" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldIPWhitelist holds the string denoting the ip_whitelist field in the database. + FieldIPWhitelist = "ip_whitelist" + // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. + FieldIPBlacklist = "ip_blacklist" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -73,6 +77,8 @@ var Columns = []string{ FieldName, FieldGroupID, FieldStatus, + FieldIPWhitelist, + FieldIPBlacklist, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 5e739006..5152867f 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey { return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) } +// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field. +func IPWhitelistIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist)) +} + +// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field. +func IPWhitelistNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist)) +} + +// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field. +func IPBlacklistIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist)) +} + +// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field. +func IPBlacklistNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 2098872c..d5363be5 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { return _c } +// SetIPWhitelist sets the "ip_whitelist" field. +func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate { + _c.mutation.SetIPWhitelist(v) + return _c +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate { + _c.mutation.SetIPBlacklist(v) + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.IPWhitelist(); ok { + _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) + _node.IPWhitelist = value + } + if value, ok := _c.mutation.IPBlacklist(); ok { + _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) + _node.IPBlacklist = value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { return u } +// SetIPWhitelist sets the "ip_whitelist" field. +func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert { + u.Set(apikey.FieldIPWhitelist, v) + return u +} + +// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert { + u.SetExcluded(apikey.FieldIPWhitelist) + return u +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert { + u.SetNull(apikey.FieldIPWhitelist) + return u +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert { + u.Set(apikey.FieldIPBlacklist, v) + return u +} + +// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert { + u.SetExcluded(apikey.FieldIPBlacklist) + return u +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert { + u.SetNull(apikey.FieldIPBlacklist) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { }) } +// SetIPWhitelist sets the "ip_whitelist" field. +func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPWhitelist(v) + }) +} + +// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPWhitelist() + }) +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPWhitelist() + }) +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPBlacklist(v) + }) +} + +// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPBlacklist() + }) +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPBlacklist() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { }) } +// SetIPWhitelist sets the "ip_whitelist" field. +func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPWhitelist(v) + }) +} + +// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPWhitelist() + }) +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPWhitelist() + }) +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetIPBlacklist(v) + }) +} + +// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateIPBlacklist() + }) +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearIPBlacklist() + }) +} + // Exec executes the query. func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 4a16369b..9ae332a8 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" @@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate { return _u } +// SetIPWhitelist sets the "ip_whitelist" field. +func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate { + _u.mutation.SetIPWhitelist(v) + return _u +} + +// AppendIPWhitelist appends value to the "ip_whitelist" field. +func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate { + _u.mutation.AppendIPWhitelist(v) + return _u +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate { + _u.mutation.ClearIPWhitelist() + return _u +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate { + _u.mutation.SetIPBlacklist(v) + return _u +} + +// AppendIPBlacklist appends value to the "ip_blacklist" field. +func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate { + _u.mutation.AppendIPBlacklist(v) + return _u +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate { + _u.mutation.ClearIPBlacklist() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.IPWhitelist(); ok { + _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPWhitelist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPWhitelist, value) + }) + } + if _u.mutation.IPWhitelistCleared() { + _spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON) + } + if value, ok := _u.mutation.IPBlacklist(); ok { + _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPBlacklist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPBlacklist, value) + }) + } + if _u.mutation.IPBlacklistCleared() { + _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne { return _u } +// SetIPWhitelist sets the "ip_whitelist" field. +func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne { + _u.mutation.SetIPWhitelist(v) + return _u +} + +// AppendIPWhitelist appends value to the "ip_whitelist" field. +func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne { + _u.mutation.AppendIPWhitelist(v) + return _u +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne { + _u.mutation.ClearIPWhitelist() + return _u +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne { + _u.mutation.SetIPBlacklist(v) + return _u +} + +// AppendIPBlacklist appends value to the "ip_blacklist" field. +func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne { + _u.mutation.AppendIPBlacklist(v) + return _u +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne { + _u.mutation.ClearIPBlacklist() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.IPWhitelist(); ok { + _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPWhitelist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPWhitelist, value) + }) + } + if _u.mutation.IPWhitelistCleared() { + _spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON) + } + if value, ok := _u.mutation.IPBlacklist(); ok { + _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedIPBlacklist(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, apikey.FieldIPBlacklist, value) + }) + } + if _u.mutation.IPBlacklistCleared() { + _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 13081e31..fdde0cd1 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -18,6 +18,8 @@ var ( {Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "name", Type: field.TypeString, Size: 100}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, + {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, + {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -29,13 +31,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[7]}, + Columns: []*schema.Column{APIKeysColumns[9]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[8]}, + Columns: []*schema.Column{APIKeysColumns[10]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -44,12 +46,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[8]}, + Columns: []*schema.Column{APIKeysColumns[10]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[7]}, + Columns: []*schema.Column{APIKeysColumns[9]}, }, { Name: "apikey_status", @@ -376,6 +378,7 @@ var ( {Name: "duration_ms", Type: field.TypeInt, Nullable: true}, {Name: "first_token_ms", Type: field.TypeInt, Nullable: true}, {Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512}, + {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, @@ -393,31 +396,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[24]}, + Columns: []*schema.Column{UsageLogsColumns[25]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -426,32 +429,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[24]}, + Columns: []*schema.Column{UsageLogsColumns[25]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[25]}, + Columns: []*schema.Column{UsageLogsColumns[26]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[23]}, + Columns: []*schema.Column{UsageLogsColumns[24]}, }, { Name: "usagelog_model", @@ -466,12 +469,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]}, + Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]}, + Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 4e01e12b..09801d4b 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -54,26 +54,30 @@ const ( // APIKeyMutation represents an operation that mutates the APIKey nodes in the graph. type APIKeyMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - key *string - name *string - status *string - clearedFields map[string]struct{} - user *int64 - cleareduser bool - group *int64 - clearedgroup bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*APIKey, error) - predicates []predicate.APIKey + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + key *string + name *string + status *string + ip_whitelist *[]string + appendip_whitelist []string + ip_blacklist *[]string + appendip_blacklist []string + clearedFields map[string]struct{} + user *int64 + cleareduser bool + group *int64 + clearedgroup bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*APIKey, error) + predicates []predicate.APIKey } var _ ent.Mutation = (*APIKeyMutation)(nil) @@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() { m.status = nil } +// SetIPWhitelist sets the "ip_whitelist" field. +func (m *APIKeyMutation) SetIPWhitelist(s []string) { + m.ip_whitelist = &s + m.appendip_whitelist = nil +} + +// IPWhitelist returns the value of the "ip_whitelist" field in the mutation. +func (m *APIKeyMutation) IPWhitelist() (r []string, exists bool) { + v := m.ip_whitelist + if v == nil { + return + } + return *v, true +} + +// OldIPWhitelist returns the old "ip_whitelist" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldIPWhitelist(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPWhitelist is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPWhitelist requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPWhitelist: %w", err) + } + return oldValue.IPWhitelist, nil +} + +// AppendIPWhitelist adds s to the "ip_whitelist" field. +func (m *APIKeyMutation) AppendIPWhitelist(s []string) { + m.appendip_whitelist = append(m.appendip_whitelist, s...) +} + +// AppendedIPWhitelist returns the list of values that were appended to the "ip_whitelist" field in this mutation. +func (m *APIKeyMutation) AppendedIPWhitelist() ([]string, bool) { + if len(m.appendip_whitelist) == 0 { + return nil, false + } + return m.appendip_whitelist, true +} + +// ClearIPWhitelist clears the value of the "ip_whitelist" field. +func (m *APIKeyMutation) ClearIPWhitelist() { + m.ip_whitelist = nil + m.appendip_whitelist = nil + m.clearedFields[apikey.FieldIPWhitelist] = struct{}{} +} + +// IPWhitelistCleared returns if the "ip_whitelist" field was cleared in this mutation. +func (m *APIKeyMutation) IPWhitelistCleared() bool { + _, ok := m.clearedFields[apikey.FieldIPWhitelist] + return ok +} + +// ResetIPWhitelist resets all changes to the "ip_whitelist" field. +func (m *APIKeyMutation) ResetIPWhitelist() { + m.ip_whitelist = nil + m.appendip_whitelist = nil + delete(m.clearedFields, apikey.FieldIPWhitelist) +} + +// SetIPBlacklist sets the "ip_blacklist" field. +func (m *APIKeyMutation) SetIPBlacklist(s []string) { + m.ip_blacklist = &s + m.appendip_blacklist = nil +} + +// IPBlacklist returns the value of the "ip_blacklist" field in the mutation. +func (m *APIKeyMutation) IPBlacklist() (r []string, exists bool) { + v := m.ip_blacklist + if v == nil { + return + } + return *v, true +} + +// OldIPBlacklist returns the old "ip_blacklist" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldIPBlacklist(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPBlacklist is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPBlacklist requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPBlacklist: %w", err) + } + return oldValue.IPBlacklist, nil +} + +// AppendIPBlacklist adds s to the "ip_blacklist" field. +func (m *APIKeyMutation) AppendIPBlacklist(s []string) { + m.appendip_blacklist = append(m.appendip_blacklist, s...) +} + +// AppendedIPBlacklist returns the list of values that were appended to the "ip_blacklist" field in this mutation. +func (m *APIKeyMutation) AppendedIPBlacklist() ([]string, bool) { + if len(m.appendip_blacklist) == 0 { + return nil, false + } + return m.appendip_blacklist, true +} + +// ClearIPBlacklist clears the value of the "ip_blacklist" field. +func (m *APIKeyMutation) ClearIPBlacklist() { + m.ip_blacklist = nil + m.appendip_blacklist = nil + m.clearedFields[apikey.FieldIPBlacklist] = struct{}{} +} + +// IPBlacklistCleared returns if the "ip_blacklist" field was cleared in this mutation. +func (m *APIKeyMutation) IPBlacklistCleared() bool { + _, ok := m.clearedFields[apikey.FieldIPBlacklist] + return ok +} + +// ResetIPBlacklist resets all changes to the "ip_blacklist" field. +func (m *APIKeyMutation) ResetIPBlacklist() { + m.ip_blacklist = nil + m.appendip_blacklist = nil + delete(m.clearedFields, apikey.FieldIPBlacklist) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 8) + fields := make([]string, 0, 10) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string { if m.status != nil { fields = append(fields, apikey.FieldStatus) } + if m.ip_whitelist != nil { + fields = append(fields, apikey.FieldIPWhitelist) + } + if m.ip_blacklist != nil { + fields = append(fields, apikey.FieldIPBlacklist) + } return fields } @@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.GroupID() case apikey.FieldStatus: return m.Status() + case apikey.FieldIPWhitelist: + return m.IPWhitelist() + case apikey.FieldIPBlacklist: + return m.IPBlacklist() } return nil, false } @@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldGroupID(ctx) case apikey.FieldStatus: return m.OldStatus(ctx) + case apikey.FieldIPWhitelist: + return m.OldIPWhitelist(ctx) + case apikey.FieldIPBlacklist: + return m.OldIPBlacklist(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case apikey.FieldIPWhitelist: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPWhitelist(v) + return nil + case apikey.FieldIPBlacklist: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPBlacklist(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldGroupID) { fields = append(fields, apikey.FieldGroupID) } + if m.FieldCleared(apikey.FieldIPWhitelist) { + fields = append(fields, apikey.FieldIPWhitelist) + } + if m.FieldCleared(apikey.FieldIPBlacklist) { + fields = append(fields, apikey.FieldIPBlacklist) + } return fields } @@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldGroupID: m.ClearGroupID() return nil + case apikey.FieldIPWhitelist: + m.ClearIPWhitelist() + return nil + case apikey.FieldIPBlacklist: + m.ClearIPBlacklist() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldStatus: m.ResetStatus() return nil + case apikey.FieldIPWhitelist: + m.ResetIPWhitelist() + return nil + case apikey.FieldIPBlacklist: + m.ResetIPBlacklist() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -8396,6 +8576,7 @@ type UsageLogMutation struct { first_token_ms *int addfirst_token_ms *int user_agent *string + ip_address *string image_count *int addimage_count *int image_size *string @@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() { delete(m.clearedFields, usagelog.FieldUserAgent) } +// SetIPAddress sets the "ip_address" field. +func (m *UsageLogMutation) SetIPAddress(s string) { + m.ip_address = &s +} + +// IPAddress returns the value of the "ip_address" field in the mutation. +func (m *UsageLogMutation) IPAddress() (r string, exists bool) { + v := m.ip_address + if v == nil { + return + } + return *v, true +} + +// OldIPAddress returns the old "ip_address" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldIPAddress(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPAddress is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPAddress requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPAddress: %w", err) + } + return oldValue.IPAddress, nil +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (m *UsageLogMutation) ClearIPAddress() { + m.ip_address = nil + m.clearedFields[usagelog.FieldIPAddress] = struct{}{} +} + +// IPAddressCleared returns if the "ip_address" field was cleared in this mutation. +func (m *UsageLogMutation) IPAddressCleared() bool { + _, ok := m.clearedFields[usagelog.FieldIPAddress] + return ok +} + +// ResetIPAddress resets all changes to the "ip_address" field. +func (m *UsageLogMutation) ResetIPAddress() { + m.ip_address = nil + delete(m.clearedFields, usagelog.FieldIPAddress) +} + // SetImageCount sets the "image_count" field. func (m *UsageLogMutation) SetImageCount(i int) { m.image_count = &i @@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 28) + fields := make([]string, 0, 29) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string { if m.user_agent != nil { fields = append(fields, usagelog.FieldUserAgent) } + if m.ip_address != nil { + fields = append(fields, usagelog.FieldIPAddress) + } if m.image_count != nil { fields = append(fields, usagelog.FieldImageCount) } @@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.FirstTokenMs() case usagelog.FieldUserAgent: return m.UserAgent() + case usagelog.FieldIPAddress: + return m.IPAddress() case usagelog.FieldImageCount: return m.ImageCount() case usagelog.FieldImageSize: @@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldFirstTokenMs(ctx) case usagelog.FieldUserAgent: return m.OldUserAgent(ctx) + case usagelog.FieldIPAddress: + return m.OldIPAddress(ctx) case usagelog.FieldImageCount: return m.OldImageCount(ctx) case usagelog.FieldImageSize: @@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetUserAgent(v) return nil + case usagelog.FieldIPAddress: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPAddress(v) + return nil case usagelog.FieldImageCount: v, ok := value.(int) if !ok { @@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldUserAgent) { fields = append(fields, usagelog.FieldUserAgent) } + if m.FieldCleared(usagelog.FieldIPAddress) { + fields = append(fields, usagelog.FieldIPAddress) + } if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } @@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldUserAgent: m.ClearUserAgent() return nil + case usagelog.FieldIPAddress: + m.ClearIPAddress() + return nil case usagelog.FieldImageSize: m.ClearImageSize() return nil @@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldUserAgent: m.ResetUserAgent() return nil + case usagelog.FieldIPAddress: + m.ResetIPAddress() + return nil case usagelog.FieldImageCount: m.ResetImageCount() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index fb1c948c..b82f2e6c 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -533,16 +533,20 @@ func init() { usagelogDescUserAgent := usagelogFields[24].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) + // usagelogDescIPAddress is the schema descriptor for ip_address field. + usagelogDescIPAddress := usagelogFields[25].Descriptor() + // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. + usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[25].Descriptor() + usagelogDescImageCount := usagelogFields[26].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[26].Descriptor() + usagelogDescImageSize := usagelogFields[27].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[27].Descriptor() + usagelogDescCreatedAt := usagelogFields[28].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 94e572c5..1b206089 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field { field.String("status"). MaxLen(20). Default(service.StatusActive), + field.JSON("ip_whitelist", []string{}). + Optional(). + Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"), + field.JSON("ip_blacklist", []string{}). + Optional(). + Comment("Blocked IPs/CIDRs"), } } diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index df955181..264a4087 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field { MaxLen(512). Optional(). Nillable(), + field.String("ip_address"). + MaxLen(45). // 支持 IPv6 + Optional(). + Nillable(), // 图片生成字段(仅 gemini-3-pro-image 等图片模型使用) field.Int("image_count"). diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 798f3a9f..cd576466 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -72,6 +72,8 @@ type UsageLog struct { FirstTokenMs *int `json:"first_token_ms,omitempty"` // UserAgent holds the value of the "user_agent" field. UserAgent *string `json:"user_agent,omitempty"` + // IPAddress holds the value of the "ip_address" field. + IPAddress *string `json:"ip_address,omitempty"` // ImageCount holds the value of the "image_count" field. ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. @@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.UserAgent = new(string) *_m.UserAgent = value.String } + case usagelog.FieldIPAddress: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field ip_address", values[i]) + } else if value.Valid { + _m.IPAddress = new(string) + *_m.IPAddress = value.String + } case usagelog.FieldImageCount: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field image_count", values[i]) @@ -512,6 +521,11 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.IPAddress; v != nil { + builder.WriteString("ip_address=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("image_count=") builder.WriteString(fmt.Sprintf("%v", _m.ImageCount)) builder.WriteString(", ") diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index d3edfb4d..c06925c4 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -64,6 +64,8 @@ const ( FieldFirstTokenMs = "first_token_ms" // FieldUserAgent holds the string denoting the user_agent field in the database. FieldUserAgent = "user_agent" + // FieldIPAddress holds the string denoting the ip_address field in the database. + FieldIPAddress = "ip_address" // FieldImageCount holds the string denoting the image_count field in the database. FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. @@ -147,6 +149,7 @@ var Columns = []string{ FieldDurationMs, FieldFirstTokenMs, FieldUserAgent, + FieldIPAddress, FieldImageCount, FieldImageSize, FieldCreatedAt, @@ -199,6 +202,8 @@ var ( DefaultStream bool // UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. UserAgentValidator func(string) error + // IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. + IPAddressValidator func(string) error // DefaultImageCount holds the default value on creation for the "image_count" field. DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. @@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUserAgent, opts...).ToFunc() } +// ByIPAddress orders the results by the ip_address field. +func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPAddress, opts...).ToFunc() +} + // ByImageCount orders the results by the image_count field. func ByImageCount(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageCount, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index c7acd59d..96b7a19c 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v)) } +// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ. +func IPAddress(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v)) +} + // ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ. func ImageCount(v int) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) @@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v)) } +// IPAddressEQ applies the EQ predicate on the "ip_address" field. +func IPAddressEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v)) +} + +// IPAddressNEQ applies the NEQ predicate on the "ip_address" field. +func IPAddressNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v)) +} + +// IPAddressIn applies the In predicate on the "ip_address" field. +func IPAddressIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...)) +} + +// IPAddressNotIn applies the NotIn predicate on the "ip_address" field. +func IPAddressNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...)) +} + +// IPAddressGT applies the GT predicate on the "ip_address" field. +func IPAddressGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v)) +} + +// IPAddressGTE applies the GTE predicate on the "ip_address" field. +func IPAddressGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v)) +} + +// IPAddressLT applies the LT predicate on the "ip_address" field. +func IPAddressLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v)) +} + +// IPAddressLTE applies the LTE predicate on the "ip_address" field. +func IPAddressLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v)) +} + +// IPAddressContains applies the Contains predicate on the "ip_address" field. +func IPAddressContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v)) +} + +// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field. +func IPAddressHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v)) +} + +// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field. +func IPAddressHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v)) +} + +// IPAddressIsNil applies the IsNil predicate on the "ip_address" field. +func IPAddressIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress)) +} + +// IPAddressNotNil applies the NotNil predicate on the "ip_address" field. +func IPAddressNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress)) +} + +// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field. +func IPAddressEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v)) +} + +// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field. +func IPAddressContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v)) +} + // ImageCountEQ applies the EQ predicate on the "image_count" field. func ImageCountEQ(v int) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index f77650ab..e63fab05 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate { return _c } +// SetIPAddress sets the "ip_address" field. +func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate { + _c.mutation.SetIPAddress(v) + return _c +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate { + if v != nil { + _c.SetIPAddress(*v) + } + return _c +} + // SetImageCount sets the "image_count" field. func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate { _c.mutation.SetImageCount(v) @@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} } } + if v, ok := _c.mutation.IPAddress(); ok { + if err := usagelog.IPAddressValidator(v); err != nil { + return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)} + } + } if _, ok := _c.mutation.ImageCount(); !ok { return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)} } @@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldUserAgent, field.TypeString, value) _node.UserAgent = &value } + if value, ok := _c.mutation.IPAddress(); ok { + _spec.SetField(usagelog.FieldIPAddress, field.TypeString, value) + _node.IPAddress = &value + } if value, ok := _c.mutation.ImageCount(); ok { _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) _node.ImageCount = value @@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert { return u } +// SetIPAddress sets the "ip_address" field. +func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert { + u.Set(usagelog.FieldIPAddress, v) + return u +} + +// UpdateIPAddress sets the "ip_address" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldIPAddress) + return u +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert { + u.SetNull(usagelog.FieldIPAddress) + return u +} + // SetImageCount sets the "image_count" field. func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert { u.Set(usagelog.FieldImageCount, v) @@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne { }) } +// SetIPAddress sets the "ip_address" field. +func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetIPAddress(v) + }) +} + +// UpdateIPAddress sets the "ip_address" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateIPAddress() + }) +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearIPAddress() + }) +} + // SetImageCount sets the "image_count" field. func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk { }) } +// SetIPAddress sets the "ip_address" field. +func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetIPAddress(v) + }) +} + +// UpdateIPAddress sets the "ip_address" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateIPAddress() + }) +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearIPAddress() + }) +} + // SetImageCount sets the "image_count" field. func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 2e77eef7..ec2acbbb 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate { return _u } +// SetIPAddress sets the "ip_address" field. +func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate { + _u.mutation.SetIPAddress(v) + return _u +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate { + if v != nil { + _u.SetIPAddress(*v) + } + return _u +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate { + _u.mutation.ClearIPAddress() + return _u +} + // SetImageCount sets the "image_count" field. func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate { _u.mutation.ResetImageCount() @@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} } } + if v, ok := _u.mutation.IPAddress(); ok { + if err := usagelog.IPAddressValidator(v); err != nil { + return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)} + } + } if v, ok := _u.mutation.ImageSize(); ok { if err := usagelog.ImageSizeValidator(v); err != nil { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} @@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.UserAgentCleared() { _spec.ClearField(usagelog.FieldUserAgent, field.TypeString) } + if value, ok := _u.mutation.IPAddress(); ok { + _spec.SetField(usagelog.FieldIPAddress, field.TypeString, value) + } + if _u.mutation.IPAddressCleared() { + _spec.ClearField(usagelog.FieldIPAddress, field.TypeString) + } if value, ok := _u.mutation.ImageCount(); ok { _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) } @@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne { return _u } +// SetIPAddress sets the "ip_address" field. +func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne { + _u.mutation.SetIPAddress(v) + return _u +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetIPAddress(*v) + } + return _u +} + +// ClearIPAddress clears the value of the "ip_address" field. +func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne { + _u.mutation.ClearIPAddress() + return _u +} + // SetImageCount sets the "image_count" field. func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne { _u.mutation.ResetImageCount() @@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} } } + if v, ok := _u.mutation.IPAddress(); ok { + if err := usagelog.IPAddressValidator(v); err != nil { + return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)} + } + } if v, ok := _u.mutation.ImageSize(); ok { if err := usagelog.ImageSizeValidator(v); err != nil { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} @@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.UserAgentCleared() { _spec.ClearField(usagelog.FieldUserAgent, field.TypeString) } + if value, ok := _u.mutation.IPAddress(); ok { + _spec.SetField(usagelog.FieldIPAddress, field.TypeString, value) + } + if _u.mutation.IPAddressCleared() { + _spec.ClearField(usagelog.FieldIPAddress, field.TypeString) + } if value, ok := _u.mutation.ImageCount(); ok { _spec.SetField(usagelog.FieldImageCount, field.TypeInt, value) } diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 09772f22..52dc6911 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler { // CreateAPIKeyRequest represents the create API key request payload type CreateAPIKeyRequest struct { - Name string `json:"name" binding:"required"` - GroupID *int64 `json:"group_id"` // nullable - CustomKey *string `json:"custom_key"` // 可选的自定义key + Name string `json:"name" binding:"required"` + GroupID *int64 `json:"group_id"` // nullable + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 } // UpdateAPIKeyRequest represents the update API key request payload type UpdateAPIKeyRequest struct { - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status" binding:"omitempty,oneof=active inactive"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status" binding:"omitempty,oneof=active inactive"` + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 } // List handles listing user's API keys with pagination @@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) { } svcReq := service.CreateAPIKeyRequest{ - Name: req.Name, - GroupID: req.GroupID, - CustomKey: req.CustomKey, + Name: req.Name, + GroupID: req.GroupID, + CustomKey: req.CustomKey, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, } key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) if err != nil { @@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) { return } - svcReq := service.UpdateAPIKeyRequest{} + svcReq := service.UpdateAPIKeyRequest{ + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + } if req.Name != "" { svcReq.Name = &req.Name } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 9a672064..85dbe6f5 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey { return nil } return &APIKey{ - ID: k.ID, - UserID: k.UserID, - Key: k.Key, - Name: k.Name, - GroupID: k.GroupID, - Status: k.Status, - CreatedAt: k.CreatedAt, - UpdatedAt: k.UpdatedAt, - User: UserFromServiceShallow(k.User), - Group: GroupFromServiceShallow(k.Group), + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + IPWhitelist: k.IPWhitelist, + IPBlacklist: k.IPBlacklist, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), } } @@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary { // usageLogFromServiceBase is a helper that converts service UsageLog to DTO. // The account parameter allows caller to control what Account info is included. -func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog { +// The includeIPAddress parameter controls whether to include the IP address (admin-only). +func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog { if l == nil { return nil } - return &UsageLog{ + result := &UsageLog{ ID: l.ID, UserID: l.UserID, APIKeyID: l.APIKeyID, @@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag Group: GroupFromServiceShallow(l.Group), Subscription: UserSubscriptionFromService(l.Subscription), } + // IP 地址仅对管理员可见 + if includeIPAddress { + result.IPAddress = l.IPAddress + } + return result } // UsageLogFromService converts a service UsageLog to DTO for regular users. -// It excludes Account details - users should not see account information. +// It excludes Account details and IP address - users should not see these. func UsageLogFromService(l *service.UsageLog) *UsageLog { - return usageLogFromServiceBase(l, nil) + return usageLogFromServiceBase(l, nil, false) } // UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users. -// It includes minimal Account info (ID, Name only). +// It includes minimal Account info (ID, Name only) and IP address. func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog { if l == nil { return nil } - return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account)) + return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true) } func SettingFromService(s *service.Setting) *Setting { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 03f7080b..ad583ad0 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -20,14 +20,16 @@ type User struct { } type APIKey struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Key string `json:"key"` - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Key string `json:"key"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` + IPBlacklist []string `json:"ip_blacklist"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` @@ -187,6 +189,9 @@ type UsageLog struct { // User-Agent UserAgent *string `json:"user_agent"` + // IP 地址(仅管理员可见) + IPAddress *string `json:"ip_address,omitempty"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 48a827f3..0d38db17 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 获取 User-Agent userAgent := c.Request.UserAgent() + // 获取客户端 IP + clientIP := ip.GetClientIP(c) + // 0. 检查wait队列是否已满 maxWait := service.CalculateMaxWait(subject.Concurrency) canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) @@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ @@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } @@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ @@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 0cbe44f2..986b174b 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 获取 User-Agent userAgent := c.Request.UserAgent() + // 获取客户端 IP + clientIP := ip.GetClientIP(c) + // For Gemini native API, do not send Claude-style ping frames. geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0) @@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } // 6) record usage async - go func(result *service.ForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ @@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 70131417..068e80ea 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -94,6 +95,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // For non-Codex CLI requests, set default instructions userAgent := c.GetHeader("User-Agent") + + // 获取客户端 IP + clientIP := ip.GetClientIP(c) + if !openai.IsCodexCLIRequest(userAgent) { reqBody["instructions"] = openai.DefaultInstructions // Re-serialize body @@ -242,7 +247,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } // Async record usage - go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) { + go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ @@ -252,10 +257,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { Account: usedAccount, Subscription: subscription, UserAgent: ua, + IPAddress: cip, }); err != nil { log.Printf("Record usage failed: %v", err) } - }(result, account, userAgent) + }(result, account, userAgent, clientIP) return } } diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go new file mode 100644 index 00000000..97109c0c --- /dev/null +++ b/backend/internal/pkg/ip/ip.go @@ -0,0 +1,168 @@ +// Package ip 提供客户端 IP 地址提取工具。 +package ip + +import ( + "net" + "strings" + + "github.com/gin-gonic/gin" +) + +// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。 +// 按以下优先级检查 Header: +// 1. CF-Connecting-IP (Cloudflare) +// 2. X-Real-IP (Nginx) +// 3. X-Forwarded-For (取第一个非私有 IP) +// 4. c.ClientIP() (Gin 内置方法) +func GetClientIP(c *gin.Context) string { + // 1. Cloudflare + if ip := c.GetHeader("CF-Connecting-IP"); ip != "" { + return normalizeIP(ip) + } + + // 2. Nginx X-Real-IP + if ip := c.GetHeader("X-Real-IP"); ip != "" { + return normalizeIP(ip) + } + + // 3. X-Forwarded-For (多个 IP 时取第一个公网 IP) + if xff := c.GetHeader("X-Forwarded-For"); xff != "" { + ips := strings.Split(xff, ",") + for _, ip := range ips { + ip = strings.TrimSpace(ip) + if ip != "" && !isPrivateIP(ip) { + return normalizeIP(ip) + } + } + // 如果都是私有 IP,返回第一个 + if len(ips) > 0 { + return normalizeIP(strings.TrimSpace(ips[0])) + } + } + + // 4. Gin 内置方法 + return normalizeIP(c.ClientIP()) +} + +// normalizeIP 规范化 IP 地址,去除端口号和空格。 +func normalizeIP(ip string) string { + ip = strings.TrimSpace(ip) + // 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1") + if host, _, err := net.SplitHostPort(ip); err == nil { + return host + } + return ip +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + + // 私有 IP 范围 + privateBlocks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + "::1/128", + "fc00::/7", + } + + for _, block := range privateBlocks { + _, cidr, err := net.ParseCIDR(block) + if err != nil { + continue + } + if cidr.Contains(ip) { + return true + } + } + return false +} + +// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR)。 +// pattern 可以是: +// - 单个 IP: "192.168.1.100" +// - CIDR 范围: "192.168.1.0/24" +func MatchesPattern(clientIP, pattern string) bool { + ip := net.ParseIP(clientIP) + if ip == nil { + return false + } + + // 尝试解析为 CIDR + if strings.Contains(pattern, "/") { + _, cidr, err := net.ParseCIDR(pattern) + if err != nil { + return false + } + return cidr.Contains(ip) + } + + // 作为单个 IP 处理 + patternIP := net.ParseIP(pattern) + if patternIP == nil { + return false + } + return ip.Equal(patternIP) +} + +// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。 +func MatchesAnyPattern(clientIP string, patterns []string) bool { + for _, pattern := range patterns { + if MatchesPattern(clientIP, pattern) { + return true + } + } + return false +} + +// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。 +// 返回值:(是否允许, 拒绝原因) +// 逻辑: +// 1. 先检查黑名单,如果在黑名单中则直接拒绝 +// 2. 如果白名单不为空,IP 必须在白名单中 +// 3. 如果白名单为空,允许访问(除非被黑名单拒绝) +func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { + // 规范化 IP + clientIP = normalizeIP(clientIP) + if clientIP == "" { + return false, "access denied" + } + + // 1. 检查黑名单 + if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { + return false, "access denied" + } + + // 2. 检查白名单(如果设置了白名单,IP 必须在其中) + if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { + return false, "access denied" + } + + return true, "" +} + +// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。 +func ValidateIPPattern(pattern string) bool { + if strings.Contains(pattern, "/") { + _, _, err := net.ParseCIDR(pattern) + return err == nil + } + return net.ParseIP(pattern) != nil +} + +// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。 +// 返回无效的模式列表。 +func ValidateIPPatterns(patterns []string) []string { + var invalid []string + for _, p := range patterns { + if !ValidateIPPattern(p) { + invalid = append(invalid, p) + } + } + return invalid +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index f3b07616..6da551da 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery { } func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error { - created, err := r.client.APIKey.Create(). + builder := r.client.APIKey.Create(). SetUserID(key.UserID). SetKey(key.Key). SetName(key.Name). SetStatus(key.Status). - SetNillableGroupID(key.GroupID). - Save(ctx) + SetNillableGroupID(key.GroupID) + + if len(key.IPWhitelist) > 0 { + builder.SetIPWhitelist(key.IPWhitelist) + } + if len(key.IPBlacklist) > 0 { + builder.SetIPBlacklist(key.IPBlacklist) + } + + created, err := builder.Save(ctx) if err == nil { key.ID = created.ID key.CreatedAt = created.CreatedAt @@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearGroupID() } + // IP 限制字段 + if len(key.IPWhitelist) > 0 { + builder.SetIPWhitelist(key.IPWhitelist) + } else { + builder.ClearIPWhitelist() + } + if len(key.IPBlacklist) > 0 { + builder.SetIPBlacklist(key.IPBlacklist) + } else { + builder.ClearIPBlacklist() + } + affected, err := builder.Save(ctx) if err != nil { return err @@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { return nil } out := &service.APIKey{ - ID: m.ID, - UserID: m.UserID, - Key: m.Key, - Name: m.Name, - Status: m.Status, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - GroupID: m.GroupID, + ID: m.ID, + UserID: m.UserID, + Key: m.Key, + Name: m.Name, + Status: m.Status, + IPWhitelist: m.IPWhitelist, + IPBlacklist: m.IPBlacklist, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + GroupID: m.GroupID, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index bd5c8b4f..6ed8910e 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at" type usageLogRepository struct { client *dbent.Client @@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration_ms, first_token_ms, user_agent, + ip_address, image_count, image_size, created_at @@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration := nullInt(log.DurationMs) firstToken := nullInt(log.FirstTokenMs) userAgent := nullString(log.UserAgent) + ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) var requestIDArg any @@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration, firstToken, userAgent, + ipAddress, log.ImageCount, imageSize, createdAt, @@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e durationMs sql.NullInt64 firstTokenMs sql.NullInt64 userAgent sql.NullString + ipAddress sql.NullString imageCount int imageSize sql.NullString createdAt time.Time @@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &durationMs, &firstTokenMs, &userAgent, + &ipAddress, &imageCount, &imageSize, &createdAt, @@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if userAgent.Valid { log.UserAgent = &userAgent.String } + if ipAddress.Valid { + log.IPAddress = &ipAddress.String + } if imageSize.Valid { log.ImageSize = &imageSize.String } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 20e82be8..6e52c5bc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -82,6 +82,8 @@ func TestAPIContracts(t *testing.T) { "name": "Key One", "group_id": null, "status": "active", + "ip_whitelist": null, + "ip_blacklist": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -116,6 +118,8 @@ func TestAPIContracts(t *testing.T) { "name": "Key One", "group_id": null, "status": "active", + "ip_whitelist": null, + "ip_blacklist": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 74ff8af3..d93724f2 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } + // 检查 IP 限制(白名单/黑名单) + // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 + if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { + clientIP := ip.GetClientIP(c) + allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) + if !allowed { + AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") + return + } + } + // 检查关联的用户 if apiKey.User == nil { AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 0cf0f4f9..8c692d09 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -3,16 +3,18 @@ package service import "time" type APIKey struct { - ID int64 - UserID int64 - Key string - Name string - GroupID *int64 - Status string - CreatedAt time.Time - UpdatedAt time.Time - User *User - Group *Group + ID int64 + UserID int64 + Key string + Name string + GroupID *int64 + Status string + IPWhitelist []string + IPBlacklist []string + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group } func (k *APIKey) IsActive() bool { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 0ffe8821..578afc1a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" ) @@ -20,6 +21,7 @@ var ( ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters") ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") + ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") ) const ( @@ -57,16 +59,20 @@ type APIKeyCache interface { // CreateAPIKeyRequest 创建API Key请求 type CreateAPIKeyRequest struct { - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - CustomKey *string `json:"custom_key"` // 可选的自定义key + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 } // UpdateAPIKeyRequest 更新API Key请求 type UpdateAPIKeyRequest struct { - Name *string `json:"name"` - GroupID *int64 `json:"group_id"` - Status *string `json:"status"` + Name *string `json:"name"` + GroupID *int64 `json:"group_id"` + Status *string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) } // APIKeyService API Key服务 @@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK return nil, fmt.Errorf("get user: %w", err) } + // 验证 IP 白名单格式 + if len(req.IPWhitelist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证 IP 黑名单格式 + if len(req.IPBlacklist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + // 验证分组权限(如果指定了分组) if req.GroupID != nil { group, err := s.groupRepo.GetByID(ctx, *req.GroupID) @@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK // 创建API Key记录 apiKey := &APIKey{ - UserID: userID, - Key: key, - Name: req.Name, - GroupID: req.GroupID, - Status: StatusActive, + UserID: userID, + Key: key, + Name: req.Name, + GroupID: req.GroupID, + Status: StatusActive, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, } if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { @@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req return nil, ErrInsufficientPerms } + // 验证 IP 白名单格式 + if len(req.IPWhitelist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + + // 验证 IP 黑名单格式 + if len(req.IPBlacklist) > 0 { + if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 { + return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid) + } + } + // 更新字段 if req.Name != nil { apiKey.Name = *req.Name @@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } } + // 更新 IP 限制(空数组会清空设置) + apiKey.IPWhitelist = req.IPWhitelist + apiKey.IPBlacklist = req.IPBlacklist + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { return nil, fmt.Errorf("update api key: %w", err) } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index e73e9406..89f7d798 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2247,6 +2247,7 @@ type RecordUsageInput struct { Account *Account Subscription *UserSubscription // 可选:订阅信息 UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -2337,6 +2338,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.UserAgent = &input.UserAgent } + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + // 添加分组和订阅关联 if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 42e98585..5bb7574a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1197,6 +1197,7 @@ type OpenAIRecordUsageInput struct { Account *Account Subscription *UserSubscription UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 } // RecordUsage records usage and deducts balance @@ -1271,6 +1272,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.UserAgent = &input.UserAgent } + // 添加 IPAddress + if input.IPAddress != "" { + usageLog.IPAddress = &input.IPAddress + } + if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 9ecb7098..62d7fae0 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -39,6 +39,7 @@ type UsageLog struct { DurationMs *int FirstTokenMs *int UserAgent *string + IPAddress *string // 图片生成字段 ImageCount int diff --git a/backend/migrations/031_add_ip_address.sql b/backend/migrations/031_add_ip_address.sql new file mode 100644 index 00000000..7f557830 --- /dev/null +++ b/backend/migrations/031_add_ip_address.sql @@ -0,0 +1,5 @@ +-- Add IP address field to usage_logs table for request tracking (admin-only visibility) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS ip_address VARCHAR(45); + +-- Create index for IP address queries +CREATE INDEX IF NOT EXISTS idx_usage_logs_ip_address ON usage_logs(ip_address); diff --git a/backend/migrations/032_add_api_key_ip_restriction.sql b/backend/migrations/032_add_api_key_ip_restriction.sql new file mode 100644 index 00000000..2dfe2c92 --- /dev/null +++ b/backend/migrations/032_add_api_key_ip_restriction.sql @@ -0,0 +1,9 @@ +-- Add IP restriction fields to api_keys table +-- ip_whitelist: JSON array of allowed IPs/CIDRs (if set, only these IPs can use the key) +-- ip_blacklist: JSON array of blocked IPs/CIDRs (these IPs are always blocked) + +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_whitelist JSONB DEFAULT NULL; +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_blacklist JSONB DEFAULT NULL; + +COMMENT ON COLUMN api_keys.ip_whitelist IS 'JSON array of allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]'; +COMMENT ON COLUMN api_keys.ip_blacklist IS 'JSON array of blocked IPs/CIDRs, e.g. ["1.2.3.4", "5.6.0.0/16"]'; diff --git a/backend/repository.test b/backend/repository.test new file mode 100755 index 00000000..9ecc014c Binary files /dev/null and b/backend/repository.test differ diff --git a/deploy/docker-compose.standalone.yml b/deploy/docker-compose.standalone.yml new file mode 100644 index 00000000..1bf247c7 --- /dev/null +++ b/deploy/docker-compose.standalone.yml @@ -0,0 +1,93 @@ +# ============================================================================= +# Sub2API Docker Compose - Standalone Configuration +# ============================================================================= +# This configuration runs only the Sub2API application. +# PostgreSQL and Redis must be provided externally. +# +# Usage: +# 1. Copy .env.example to .env and configure database/redis connection +# 2. docker-compose -f docker-compose.standalone.yml up -d +# 3. Access: http://localhost:8080 +# ============================================================================= + +services: + sub2api: + image: weishaw/sub2api:latest + container_name: sub2api + restart: unless-stopped + ulimits: + nofile: + soft: 100000 + hard: 100000 + ports: + - "${BIND_HOST:-0.0.0.0}:${SERVER_PORT:-8080}:8080" + volumes: + - sub2api_data:/app/data + extra_hosts: + - "host.docker.internal:host-gateway" + environment: + # ======================================================================= + # Auto Setup + # ======================================================================= + - AUTO_SETUP=true + + # ======================================================================= + # Server Configuration + # ======================================================================= + - SERVER_HOST=0.0.0.0 + - SERVER_PORT=8080 + - SERVER_MODE=${SERVER_MODE:-release} + - RUN_MODE=${RUN_MODE:-standard} + + # ======================================================================= + # Database Configuration (PostgreSQL) - Required + # ======================================================================= + - DATABASE_HOST=${DATABASE_HOST:?DATABASE_HOST is required} + - DATABASE_PORT=${DATABASE_PORT:-5432} + - DATABASE_USER=${DATABASE_USER:-sub2api} + - DATABASE_PASSWORD=${DATABASE_PASSWORD:?DATABASE_PASSWORD is required} + - DATABASE_DBNAME=${DATABASE_DBNAME:-sub2api} + - DATABASE_SSLMODE=${DATABASE_SSLMODE:-disable} + + # ======================================================================= + # Redis Configuration - Required + # ======================================================================= + - REDIS_HOST=${REDIS_HOST:?REDIS_HOST is required} + - REDIS_PORT=${REDIS_PORT:-6379} + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - REDIS_DB=${REDIS_DB:-0} + + # ======================================================================= + # Admin Account (auto-created on first run) + # ======================================================================= + - ADMIN_EMAIL=${ADMIN_EMAIL:-admin@sub2api.local} + - ADMIN_PASSWORD=${ADMIN_PASSWORD:-} + + # ======================================================================= + # JWT Configuration + # ======================================================================= + - JWT_SECRET=${JWT_SECRET:-} + - JWT_EXPIRE_HOUR=${JWT_EXPIRE_HOUR:-24} + + # ======================================================================= + # Timezone Configuration + # ======================================================================= + - TZ=${TZ:-Asia/Shanghai} + + # ======================================================================= + # Gemini OAuth Configuration (optional) + # ======================================================================= + - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} + - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} + - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} + - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + +volumes: + sub2api_data: + driver: local diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index 4712dafd..ca76234b 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -64,7 +64,6 @@ export async function getStats(params: { group_id?: number model?: string stream?: boolean - billing_type?: number period?: string start_date?: string end_date?: string diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index caa339e4..cdae1359 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -42,12 +42,16 @@ export async function getById(id: number): Promise { * @param name - Key name * @param groupId - Optional group ID * @param customKey - Optional custom key value + * @param ipWhitelist - Optional IP whitelist + * @param ipBlacklist - Optional IP blacklist * @returns Created API key */ export async function create( name: string, groupId?: number | null, - customKey?: string + customKey?: string, + ipWhitelist?: string[], + ipBlacklist?: string[] ): Promise { const payload: CreateApiKeyRequest = { name } if (groupId !== undefined) { @@ -56,6 +60,12 @@ export async function create( if (customKey) { payload.custom_key = customKey } + if (ipWhitelist && ipWhitelist.length > 0) { + payload.ip_whitelist = ipWhitelist + } + if (ipBlacklist && ipBlacklist.length > 0) { + payload.ip_blacklist = ipBlacklist + } const { data } = await apiClient.post('/keys', payload) return data diff --git a/frontend/src/components/admin/usage/UsageFilters.vue b/frontend/src/components/admin/usage/UsageFilters.vue index 924f5fb6..0926d83c 100644 --- a/frontend/src/components/admin/usage/UsageFilters.vue +++ b/frontend/src/components/admin/usage/UsageFilters.vue @@ -127,12 +127,6 @@ - -
@@ -227,12 +221,6 @@ const streamTypeOptions = ref([ { value: false, label: t('usage.sync') } ]) -const billingTypeOptions = ref([ - { value: null, label: t('admin.usage.allBillingTypes') }, - { value: 1, label: t('usage.subscription') }, - { value: 0, label: t('usage.balance') } -]) - const emitChange = () => emit('change') const updateStartDate = (value: string) => { diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index c53b8c90..a66e4b7b 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -96,12 +96,6 @@
- - + + @@ -249,11 +248,11 @@ const cols = computed(() => [ { key: 'stream', label: t('usage.type'), sortable: false }, { key: 'tokens', label: t('usage.tokens'), sortable: false }, { key: 'cost', label: t('usage.cost'), sortable: false }, - { key: 'billing_type', label: t('usage.billingType'), sortable: false }, { key: 'first_token', label: t('usage.firstToken'), sortable: false }, { key: 'duration', label: t('usage.duration'), sortable: false }, { key: 'created_at', label: t('usage.time'), sortable: true }, - { key: 'user_agent', label: t('usage.userAgent'), sortable: false } + { key: 'user_agent', label: t('usage.userAgent'), sortable: false }, + { key: 'ip_address', label: t('admin.usage.ipAddress'), sortable: false } ]) const formatCacheTokens = (tokens: number): string => { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index ca220281..e7d3a28d 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -370,6 +370,14 @@ export default { customKeyTooShort: 'Custom key must be at least 16 characters', customKeyInvalidChars: 'Custom key can only contain letters, numbers, underscores, and hyphens', customKeyRequired: 'Please enter a custom key', + ipRestriction: 'IP Restriction', + ipWhitelist: 'IP Whitelist', + ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8', + ipWhitelistHint: 'One IP or CIDR per line. Only these IPs can use this key when set.', + ipBlacklist: 'IP Blacklist', + ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16', + ipBlacklistHint: 'One IP or CIDR per line. These IPs will be blocked from using this key.', + ipRestrictionEnabled: 'IP restriction enabled', ccSwitchNotInstalled: 'CC-Switch is not installed or the protocol handler is not registered. Please install CC-Switch first or manually copy the API key.', ccsClientSelect: { title: 'Select Client', @@ -430,9 +438,6 @@ export default { exportFailed: 'Failed to export usage data', exportExcelSuccess: 'Usage data exported successfully (Excel format)', exportExcelFailed: 'Failed to export usage data', - billingType: 'Billing', - balance: 'Balance', - subscription: 'Subscription', imageUnit: ' images', userAgent: 'User-Agent' }, @@ -1735,7 +1740,6 @@ export default { allAccounts: 'All Accounts', allGroups: 'All Groups', allTypes: 'All Types', - allBillingTypes: 'All Billing', inputCost: 'Input Cost', outputCost: 'Output Cost', cacheCreationCost: 'Cache Creation Cost', @@ -1744,7 +1748,8 @@ export default { outputTokens: 'Output Tokens', cacheCreationTokens: 'Cache Creation Tokens', cacheReadTokens: 'Cache Read Tokens', - failedToLoad: 'Failed to load usage records' + failedToLoad: 'Failed to load usage records', + ipAddress: 'IP' }, // Settings diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 6749c02e..fc1e6fff 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -367,6 +367,14 @@ export default { customKeyTooShort: '自定义密钥至少需要16个字符', customKeyInvalidChars: '自定义密钥只能包含字母、数字、下划线和连字符', customKeyRequired: '请输入自定义密钥', + ipRestriction: 'IP 限制', + ipWhitelist: 'IP 白名单', + ipWhitelistPlaceholder: '192.168.1.100\n10.0.0.0/8', + ipWhitelistHint: '每行一个 IP 或 CIDR,设置后仅允许这些 IP 使用此密钥', + ipBlacklist: 'IP 黑名单', + ipBlacklistPlaceholder: '1.2.3.4\n5.6.0.0/16', + ipBlacklistHint: '每行一个 IP 或 CIDR,这些 IP 将被禁止使用此密钥', + ipRestrictionEnabled: '已配置 IP 限制', ccSwitchNotInstalled: 'CC-Switch 未安装或协议处理程序未注册。请先安装 CC-Switch 或手动复制 API 密钥。', ccsClientSelect: { title: '选择客户端', @@ -427,9 +435,6 @@ export default { exportFailed: '使用数据导出失败', exportExcelSuccess: '使用数据导出成功(Excel格式)', exportExcelFailed: '使用数据导出失败', - billingType: '消费类型', - balance: '余额', - subscription: '订阅', imageUnit: '张', userAgent: 'User-Agent' }, @@ -1880,7 +1885,6 @@ export default { allAccounts: '全部账户', allGroups: '全部分组', allTypes: '全部类型', - allBillingTypes: '全部计费', inputCost: '输入成本', outputCost: '输出成本', cacheCreationCost: '缓存创建成本', @@ -1889,7 +1893,8 @@ export default { outputTokens: '输出 Token', cacheCreationTokens: '缓存创建 Token', cacheReadTokens: '缓存读取 Token', - failedToLoad: '加载使用记录失败' + failedToLoad: '加载使用记录失败', + ipAddress: 'IP' }, // Settings diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 4b8fff09..bc858c6a 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -279,6 +279,8 @@ export interface ApiKey { name: string group_id: number | null status: 'active' | 'inactive' + ip_whitelist: string[] + ip_blacklist: string[] created_at: string updated_at: string group?: Group @@ -288,12 +290,16 @@ export interface CreateApiKeyRequest { name: string group_id?: number | null custom_key?: string // Optional custom API Key + ip_whitelist?: string[] + ip_blacklist?: string[] } export interface UpdateApiKeyRequest { name?: string group_id?: number | null status?: 'active' | 'inactive' + ip_whitelist?: string[] + ip_blacklist?: string[] } export interface CreateGroupRequest { @@ -560,9 +566,6 @@ export interface UpdateProxyRequest { export type RedeemCodeType = 'balance' | 'concurrency' | 'subscription' -// 消费类型: 0=钱包余额, 1=订阅套餐 -export type BillingType = 0 | 1 - export interface UsageLog { id: number user_id: number @@ -589,7 +592,6 @@ export interface UsageLog { actual_cost: number rate_multiplier: number - billing_type: BillingType stream: boolean duration_ms: number first_token_ms: number | null @@ -601,6 +603,9 @@ export interface UsageLog { // User-Agent user_agent: string | null + // IP 地址(仅管理员可见) + ip_address: string | null + created_at: string user?: User @@ -830,7 +835,6 @@ export interface UsageQueryParams { group_id?: number model?: string stream?: boolean - billing_type?: number start_date?: string end_date?: string } diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index 47af8141..fbde13fd 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -95,8 +95,8 @@ const exportToExcel = async () => { t('admin.usage.inputCost'), t('admin.usage.outputCost'), t('admin.usage.cacheReadCost'), t('admin.usage.cacheCreationCost'), t('usage.rate'), t('usage.original'), t('usage.billed'), - t('usage.billingType'), t('usage.firstToken'), t('usage.duration'), - t('admin.usage.requestId'), t('usage.userAgent') + t('usage.firstToken'), t('usage.duration'), + t('admin.usage.requestId'), t('usage.userAgent'), t('admin.usage.ipAddress') ] const rows = all.map(log => [ log.created_at, @@ -117,11 +117,11 @@ const exportToExcel = async () => { log.rate_multiplier?.toFixed(2) || '1.00', log.total_cost?.toFixed(6) || '0.000000', log.actual_cost?.toFixed(6) || '0.000000', - log.billing_type === 1 ? t('usage.subscription') : t('usage.balance'), log.first_token_ms ?? '', log.duration_ms, log.request_id || '', - log.user_agent || '' + log.user_agent || '', + log.ip_address || '' ]) const ws = XLSX.utils.aoa_to_sheet([headers, ...rows]) const wb = XLSX.utils.book_new() diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 6d4e3c96..0787c467 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -46,8 +46,17 @@ -