feat(api-key): add IP whitelist/blacklist restriction and usage log IP tracking

- Add IP restriction feature for API keys (whitelist/blacklist with CIDR support)
- Add IP address logging to usage logs (admin-only visibility)
- Remove billing_type column from usage logs UI (redundant)
- Use generic "Access denied" error message for security

Backend:
- New ip package with IP/CIDR validation and matching utilities
- Database migrations for ip_whitelist, ip_blacklist (api_keys) and ip_address (usage_logs)
- Middleware IP restriction check after API key validation
- Input validation for IP/CIDR patterns on create/update

Frontend:
- API key form with enable toggle for IP restriction
- Shield icon indicator in table for keys with IP restriction
- Removed billing_type filter and column from usage views
This commit is contained in:
Edric Li
2026-01-09 21:24:59 +08:00
parent 8f24d239af
commit 90798f14b5
42 changed files with 1403 additions and 183 deletions

View File

@@ -3,6 +3,7 @@
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
@@ -35,6 +36,10 @@ type APIKey struct {
GroupID *int64 `json:"group_id,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// Allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]
IPWhitelist []string `json:"ip_whitelist,omitempty"`
// Blocked IPs/CIDRs
IPBlacklist []string `json:"ip_blacklist,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the APIKeyQuery when eager-loading is set.
Edges APIKeyEdges `json:"edges"`
@@ -90,6 +95,8 @@ func (*APIKey) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist:
values[i] = new([]byte)
case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID:
values[i] = new(sql.NullInt64)
case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus:
@@ -167,6 +174,22 @@ func (_m *APIKey) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Status = value.String
}
case apikey.FieldIPWhitelist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPWhitelist); err != nil {
return fmt.Errorf("unmarshal field ip_whitelist: %w", err)
}
}
case apikey.FieldIPBlacklist:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field ip_blacklist", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.IPBlacklist); err != nil {
return fmt.Errorf("unmarshal field ip_blacklist: %w", err)
}
}
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -245,6 +268,12 @@ func (_m *APIKey) String() string {
builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
builder.WriteString("ip_whitelist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist))
builder.WriteString(", ")
builder.WriteString("ip_blacklist=")
builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist))
builder.WriteByte(')')
return builder.String()
}

View File

@@ -31,6 +31,10 @@ const (
FieldGroupID = "group_id"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldIPWhitelist holds the string denoting the ip_whitelist field in the database.
FieldIPWhitelist = "ip_whitelist"
// FieldIPBlacklist holds the string denoting the ip_blacklist field in the database.
FieldIPBlacklist = "ip_blacklist"
// EdgeUser holds the string denoting the user edge name in mutations.
EdgeUser = "user"
// EdgeGroup holds the string denoting the group edge name in mutations.
@@ -73,6 +77,8 @@ var Columns = []string{
FieldName,
FieldGroupID,
FieldStatus,
FieldIPWhitelist,
FieldIPBlacklist,
}
// ValidColumn reports if the column name is valid (part of the table columns).

View File

@@ -470,6 +470,26 @@ func StatusContainsFold(v string) predicate.APIKey {
return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v))
}
// IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field.
func IPWhitelistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist))
}
// IPWhitelistNotNil applies the NotNil predicate on the "ip_whitelist" field.
func IPWhitelistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPWhitelist))
}
// IPBlacklistIsNil applies the IsNil predicate on the "ip_blacklist" field.
func IPBlacklistIsNil() predicate.APIKey {
return predicate.APIKey(sql.FieldIsNull(FieldIPBlacklist))
}
// IPBlacklistNotNil applies the NotNil predicate on the "ip_blacklist" field.
func IPBlacklistNotNil() predicate.APIKey {
return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist))
}
// HasUser applies the HasEdge predicate on the "user" edge.
func HasUser() predicate.APIKey {
return predicate.APIKey(func(s *sql.Selector) {

View File

@@ -113,6 +113,18 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate {
return _c
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate {
_c.mutation.SetIPWhitelist(v)
return _c
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate {
_c.mutation.SetIPBlacklist(v)
return _c
}
// SetUser sets the "user" edge to the User entity.
func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate {
return _c.SetUserID(v.ID)
@@ -285,6 +297,14 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
_node.Status = value
}
if value, ok := _c.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
_node.IPWhitelist = value
}
if value, ok := _c.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
_node.IPBlacklist = value
}
if nodes := _c.mutation.UserIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -483,6 +503,42 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert {
return u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPWhitelist, v)
return u
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPWhitelist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPWhitelist)
return u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsert) ClearIPWhitelist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPWhitelist)
return u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsert) SetIPBlacklist(v []string) *APIKeyUpsert {
u.Set(apikey.FieldIPBlacklist, v)
return u
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsert) UpdateIPBlacklist() *APIKeyUpsert {
u.SetExcluded(apikey.FieldIPBlacklist)
return u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert {
u.SetNull(apikey.FieldIPBlacklist)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -640,6 +696,48 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne {
})
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertOne) ClearIPWhitelist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertOne) SetIPBlacklist(v []string) *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertOne) UpdateIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query.
func (u *APIKeyUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -963,6 +1061,48 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk {
})
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPWhitelist(v)
})
}
// UpdateIPWhitelist sets the "ip_whitelist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPWhitelist()
})
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (u *APIKeyUpsertBulk) ClearIPWhitelist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPWhitelist()
})
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) SetIPBlacklist(v []string) *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.SetIPBlacklist(v)
})
}
// UpdateIPBlacklist sets the "ip_blacklist" field to the value that was provided on create.
func (u *APIKeyUpsertBulk) UpdateIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.UpdateIPBlacklist()
})
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk {
return u.Update(func(s *APIKeyUpsert) {
s.ClearIPBlacklist()
})
}
// Exec executes the query.
func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {

View File

@@ -10,6 +10,7 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/dialect/sql/sqljson"
"entgo.io/ent/schema/field"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
@@ -133,6 +134,42 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate {
return _u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.SetIPWhitelist(v)
return _u
}
// AppendIPWhitelist appends value to the "ip_whitelist" field.
func (_u *APIKeyUpdate) AppendIPWhitelist(v []string) *APIKeyUpdate {
_u.mutation.AppendIPWhitelist(v)
return _u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (_u *APIKeyUpdate) ClearIPWhitelist() *APIKeyUpdate {
_u.mutation.ClearIPWhitelist()
return _u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_u *APIKeyUpdate) SetIPBlacklist(v []string) *APIKeyUpdate {
_u.mutation.SetIPBlacklist(v)
return _u
}
// AppendIPBlacklist appends value to the "ip_blacklist" field.
func (_u *APIKeyUpdate) AppendIPBlacklist(v []string) *APIKeyUpdate {
_u.mutation.AppendIPBlacklist(v)
return _u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate {
_u.mutation.ClearIPBlacklist()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate {
return _u.SetUserID(v.ID)
@@ -291,6 +328,28 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPWhitelist, value)
})
}
if _u.mutation.IPWhitelistCleared() {
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
}
if value, ok := _u.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPBlacklist, value)
})
}
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,
@@ -516,6 +575,42 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne {
return _u
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPWhitelist(v)
return _u
}
// AppendIPWhitelist appends value to the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) AppendIPWhitelist(v []string) *APIKeyUpdateOne {
_u.mutation.AppendIPWhitelist(v)
return _u
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (_u *APIKeyUpdateOne) ClearIPWhitelist() *APIKeyUpdateOne {
_u.mutation.ClearIPWhitelist()
return _u
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) SetIPBlacklist(v []string) *APIKeyUpdateOne {
_u.mutation.SetIPBlacklist(v)
return _u
}
// AppendIPBlacklist appends value to the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) AppendIPBlacklist(v []string) *APIKeyUpdateOne {
_u.mutation.AppendIPBlacklist(v)
return _u
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne {
_u.mutation.ClearIPBlacklist()
return _u
}
// SetUser sets the "user" edge to the User entity.
func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne {
return _u.SetUserID(v.ID)
@@ -704,6 +799,28 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(apikey.FieldStatus, field.TypeString, value)
}
if value, ok := _u.mutation.IPWhitelist(); ok {
_spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPWhitelist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPWhitelist, value)
})
}
if _u.mutation.IPWhitelistCleared() {
_spec.ClearField(apikey.FieldIPWhitelist, field.TypeJSON)
}
if value, ok := _u.mutation.IPBlacklist(); ok {
_spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value)
}
if value, ok := _u.mutation.AppendedIPBlacklist(); ok {
_spec.AddModifier(func(u *sql.UpdateBuilder) {
sqljson.Append(u, apikey.FieldIPBlacklist, value)
})
}
if _u.mutation.IPBlacklistCleared() {
_spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON)
}
if _u.mutation.UserCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.M2O,

View File

@@ -18,6 +18,8 @@ var (
{Name: "key", Type: field.TypeString, Unique: true, Size: 128},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true},
{Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true},
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "user_id", Type: field.TypeInt64},
}
@@ -29,13 +31,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "api_keys_groups_api_keys",
Columns: []*schema.Column{APIKeysColumns[7]},
Columns: []*schema.Column{APIKeysColumns[9]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "api_keys_users_api_keys",
Columns: []*schema.Column{APIKeysColumns[8]},
Columns: []*schema.Column{APIKeysColumns[10]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
@@ -44,12 +46,12 @@ var (
{
Name: "apikey_user_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[8]},
Columns: []*schema.Column{APIKeysColumns[10]},
},
{
Name: "apikey_group_id",
Unique: false,
Columns: []*schema.Column{APIKeysColumns[7]},
Columns: []*schema.Column{APIKeysColumns[9]},
},
{
Name: "apikey_status",
@@ -376,6 +378,7 @@ var (
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
{Name: "first_token_ms", Type: field.TypeInt, Nullable: true},
{Name: "user_agent", Type: field.TypeString, Nullable: true, Size: 512},
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -393,31 +396,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[25]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -426,32 +429,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27]},
Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24]},
Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[25]},
Columns: []*schema.Column{UsageLogsColumns[26]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[26]},
Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[28]},
Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[23]},
Columns: []*schema.Column{UsageLogsColumns[24]},
},
{
Name: "usagelog_model",
@@ -466,12 +469,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[23]},
Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[24], UsageLogsColumns[23]},
Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
},
},
}

View File

@@ -54,26 +54,30 @@ const (
// APIKeyMutation represents an operation that mutates the APIKey nodes in the graph.
type APIKeyMutation struct {
config
op Op
typ string
id *int64
created_at *time.Time
updated_at *time.Time
deleted_at *time.Time
key *string
name *string
status *string
clearedFields map[string]struct{}
user *int64
cleareduser bool
group *int64
clearedgroup bool
usage_logs map[int64]struct{}
removedusage_logs map[int64]struct{}
clearedusage_logs bool
done bool
oldValue func(context.Context) (*APIKey, error)
predicates []predicate.APIKey
op Op
typ string
id *int64
created_at *time.Time
updated_at *time.Time
deleted_at *time.Time
key *string
name *string
status *string
ip_whitelist *[]string
appendip_whitelist []string
ip_blacklist *[]string
appendip_blacklist []string
clearedFields map[string]struct{}
user *int64
cleareduser bool
group *int64
clearedgroup bool
usage_logs map[int64]struct{}
removedusage_logs map[int64]struct{}
clearedusage_logs bool
done bool
oldValue func(context.Context) (*APIKey, error)
predicates []predicate.APIKey
}
var _ ent.Mutation = (*APIKeyMutation)(nil)
@@ -488,6 +492,136 @@ func (m *APIKeyMutation) ResetStatus() {
m.status = nil
}
// SetIPWhitelist sets the "ip_whitelist" field.
func (m *APIKeyMutation) SetIPWhitelist(s []string) {
m.ip_whitelist = &s
m.appendip_whitelist = nil
}
// IPWhitelist returns the value of the "ip_whitelist" field in the mutation.
func (m *APIKeyMutation) IPWhitelist() (r []string, exists bool) {
v := m.ip_whitelist
if v == nil {
return
}
return *v, true
}
// OldIPWhitelist returns the old "ip_whitelist" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldIPWhitelist(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPWhitelist is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPWhitelist requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPWhitelist: %w", err)
}
return oldValue.IPWhitelist, nil
}
// AppendIPWhitelist adds s to the "ip_whitelist" field.
func (m *APIKeyMutation) AppendIPWhitelist(s []string) {
m.appendip_whitelist = append(m.appendip_whitelist, s...)
}
// AppendedIPWhitelist returns the list of values that were appended to the "ip_whitelist" field in this mutation.
func (m *APIKeyMutation) AppendedIPWhitelist() ([]string, bool) {
if len(m.appendip_whitelist) == 0 {
return nil, false
}
return m.appendip_whitelist, true
}
// ClearIPWhitelist clears the value of the "ip_whitelist" field.
func (m *APIKeyMutation) ClearIPWhitelist() {
m.ip_whitelist = nil
m.appendip_whitelist = nil
m.clearedFields[apikey.FieldIPWhitelist] = struct{}{}
}
// IPWhitelistCleared returns if the "ip_whitelist" field was cleared in this mutation.
func (m *APIKeyMutation) IPWhitelistCleared() bool {
_, ok := m.clearedFields[apikey.FieldIPWhitelist]
return ok
}
// ResetIPWhitelist resets all changes to the "ip_whitelist" field.
func (m *APIKeyMutation) ResetIPWhitelist() {
m.ip_whitelist = nil
m.appendip_whitelist = nil
delete(m.clearedFields, apikey.FieldIPWhitelist)
}
// SetIPBlacklist sets the "ip_blacklist" field.
func (m *APIKeyMutation) SetIPBlacklist(s []string) {
m.ip_blacklist = &s
m.appendip_blacklist = nil
}
// IPBlacklist returns the value of the "ip_blacklist" field in the mutation.
func (m *APIKeyMutation) IPBlacklist() (r []string, exists bool) {
v := m.ip_blacklist
if v == nil {
return
}
return *v, true
}
// OldIPBlacklist returns the old "ip_blacklist" field's value of the APIKey entity.
// If the APIKey object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *APIKeyMutation) OldIPBlacklist(ctx context.Context) (v []string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPBlacklist is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPBlacklist requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPBlacklist: %w", err)
}
return oldValue.IPBlacklist, nil
}
// AppendIPBlacklist adds s to the "ip_blacklist" field.
func (m *APIKeyMutation) AppendIPBlacklist(s []string) {
m.appendip_blacklist = append(m.appendip_blacklist, s...)
}
// AppendedIPBlacklist returns the list of values that were appended to the "ip_blacklist" field in this mutation.
func (m *APIKeyMutation) AppendedIPBlacklist() ([]string, bool) {
if len(m.appendip_blacklist) == 0 {
return nil, false
}
return m.appendip_blacklist, true
}
// ClearIPBlacklist clears the value of the "ip_blacklist" field.
func (m *APIKeyMutation) ClearIPBlacklist() {
m.ip_blacklist = nil
m.appendip_blacklist = nil
m.clearedFields[apikey.FieldIPBlacklist] = struct{}{}
}
// IPBlacklistCleared returns if the "ip_blacklist" field was cleared in this mutation.
func (m *APIKeyMutation) IPBlacklistCleared() bool {
_, ok := m.clearedFields[apikey.FieldIPBlacklist]
return ok
}
// ResetIPBlacklist resets all changes to the "ip_blacklist" field.
func (m *APIKeyMutation) ResetIPBlacklist() {
m.ip_blacklist = nil
m.appendip_blacklist = nil
delete(m.clearedFields, apikey.FieldIPBlacklist)
}
// ClearUser clears the "user" edge to the User entity.
func (m *APIKeyMutation) ClearUser() {
m.cleareduser = true
@@ -630,7 +764,7 @@ func (m *APIKeyMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *APIKeyMutation) Fields() []string {
fields := make([]string, 0, 8)
fields := make([]string, 0, 10)
if m.created_at != nil {
fields = append(fields, apikey.FieldCreatedAt)
}
@@ -655,6 +789,12 @@ func (m *APIKeyMutation) Fields() []string {
if m.status != nil {
fields = append(fields, apikey.FieldStatus)
}
if m.ip_whitelist != nil {
fields = append(fields, apikey.FieldIPWhitelist)
}
if m.ip_blacklist != nil {
fields = append(fields, apikey.FieldIPBlacklist)
}
return fields
}
@@ -679,6 +819,10 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) {
return m.GroupID()
case apikey.FieldStatus:
return m.Status()
case apikey.FieldIPWhitelist:
return m.IPWhitelist()
case apikey.FieldIPBlacklist:
return m.IPBlacklist()
}
return nil, false
}
@@ -704,6 +848,10 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldGroupID(ctx)
case apikey.FieldStatus:
return m.OldStatus(ctx)
case apikey.FieldIPWhitelist:
return m.OldIPWhitelist(ctx)
case apikey.FieldIPBlacklist:
return m.OldIPBlacklist(ctx)
}
return nil, fmt.Errorf("unknown APIKey field %s", name)
}
@@ -769,6 +917,20 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error {
}
m.SetStatus(v)
return nil
case apikey.FieldIPWhitelist:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPWhitelist(v)
return nil
case apikey.FieldIPBlacklist:
v, ok := value.([]string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPBlacklist(v)
return nil
}
return fmt.Errorf("unknown APIKey field %s", name)
}
@@ -808,6 +970,12 @@ func (m *APIKeyMutation) ClearedFields() []string {
if m.FieldCleared(apikey.FieldGroupID) {
fields = append(fields, apikey.FieldGroupID)
}
if m.FieldCleared(apikey.FieldIPWhitelist) {
fields = append(fields, apikey.FieldIPWhitelist)
}
if m.FieldCleared(apikey.FieldIPBlacklist) {
fields = append(fields, apikey.FieldIPBlacklist)
}
return fields
}
@@ -828,6 +996,12 @@ func (m *APIKeyMutation) ClearField(name string) error {
case apikey.FieldGroupID:
m.ClearGroupID()
return nil
case apikey.FieldIPWhitelist:
m.ClearIPWhitelist()
return nil
case apikey.FieldIPBlacklist:
m.ClearIPBlacklist()
return nil
}
return fmt.Errorf("unknown APIKey nullable field %s", name)
}
@@ -860,6 +1034,12 @@ func (m *APIKeyMutation) ResetField(name string) error {
case apikey.FieldStatus:
m.ResetStatus()
return nil
case apikey.FieldIPWhitelist:
m.ResetIPWhitelist()
return nil
case apikey.FieldIPBlacklist:
m.ResetIPBlacklist()
return nil
}
return fmt.Errorf("unknown APIKey field %s", name)
}
@@ -8396,6 +8576,7 @@ type UsageLogMutation struct {
first_token_ms *int
addfirst_token_ms *int
user_agent *string
ip_address *string
image_count *int
addimage_count *int
image_size *string
@@ -9801,6 +9982,55 @@ func (m *UsageLogMutation) ResetUserAgent() {
delete(m.clearedFields, usagelog.FieldUserAgent)
}
// SetIPAddress sets the "ip_address" field.
func (m *UsageLogMutation) SetIPAddress(s string) {
m.ip_address = &s
}
// IPAddress returns the value of the "ip_address" field in the mutation.
func (m *UsageLogMutation) IPAddress() (r string, exists bool) {
v := m.ip_address
if v == nil {
return
}
return *v, true
}
// OldIPAddress returns the old "ip_address" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldIPAddress(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPAddress is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPAddress requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPAddress: %w", err)
}
return oldValue.IPAddress, nil
}
// ClearIPAddress clears the value of the "ip_address" field.
func (m *UsageLogMutation) ClearIPAddress() {
m.ip_address = nil
m.clearedFields[usagelog.FieldIPAddress] = struct{}{}
}
// IPAddressCleared returns if the "ip_address" field was cleared in this mutation.
func (m *UsageLogMutation) IPAddressCleared() bool {
_, ok := m.clearedFields[usagelog.FieldIPAddress]
return ok
}
// ResetIPAddress resets all changes to the "ip_address" field.
func (m *UsageLogMutation) ResetIPAddress() {
m.ip_address = nil
delete(m.clearedFields, usagelog.FieldIPAddress)
}
// SetImageCount sets the "image_count" field.
func (m *UsageLogMutation) SetImageCount(i int) {
m.image_count = &i
@@ -10111,7 +10341,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 28)
fields := make([]string, 0, 29)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -10187,6 +10417,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.user_agent != nil {
fields = append(fields, usagelog.FieldUserAgent)
}
if m.ip_address != nil {
fields = append(fields, usagelog.FieldIPAddress)
}
if m.image_count != nil {
fields = append(fields, usagelog.FieldImageCount)
}
@@ -10254,6 +10487,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.FirstTokenMs()
case usagelog.FieldUserAgent:
return m.UserAgent()
case usagelog.FieldIPAddress:
return m.IPAddress()
case usagelog.FieldImageCount:
return m.ImageCount()
case usagelog.FieldImageSize:
@@ -10319,6 +10554,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldFirstTokenMs(ctx)
case usagelog.FieldUserAgent:
return m.OldUserAgent(ctx)
case usagelog.FieldIPAddress:
return m.OldIPAddress(ctx)
case usagelog.FieldImageCount:
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
@@ -10509,6 +10746,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetUserAgent(v)
return nil
case usagelog.FieldIPAddress:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPAddress(v)
return nil
case usagelog.FieldImageCount:
v, ok := value.(int)
if !ok {
@@ -10782,6 +11026,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldUserAgent) {
fields = append(fields, usagelog.FieldUserAgent)
}
if m.FieldCleared(usagelog.FieldIPAddress) {
fields = append(fields, usagelog.FieldIPAddress)
}
if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize)
}
@@ -10814,6 +11061,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldUserAgent:
m.ClearUserAgent()
return nil
case usagelog.FieldIPAddress:
m.ClearIPAddress()
return nil
case usagelog.FieldImageSize:
m.ClearImageSize()
return nil
@@ -10900,6 +11150,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldUserAgent:
m.ResetUserAgent()
return nil
case usagelog.FieldIPAddress:
m.ResetIPAddress()
return nil
case usagelog.FieldImageCount:
m.ResetImageCount()
return nil

View File

@@ -533,16 +533,20 @@ func init() {
usagelogDescUserAgent := usagelogFields[24].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
usagelogDescIPAddress := usagelogFields[25].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
usagelogDescImageCount := usagelogFields[25].Descriptor()
usagelogDescImageCount := usagelogFields[26].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
usagelogDescImageSize := usagelogFields[26].Descriptor()
usagelogDescImageSize := usagelogFields[27].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[27].Descriptor()
usagelogDescCreatedAt := usagelogFields[28].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@@ -46,6 +46,12 @@ func (APIKey) Fields() []ent.Field {
field.String("status").
MaxLen(20).
Default(service.StatusActive),
field.JSON("ip_whitelist", []string{}).
Optional().
Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"),
field.JSON("ip_blacklist", []string{}).
Optional().
Comment("Blocked IPs/CIDRs"),
}
}

View File

@@ -100,6 +100,10 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(512).
Optional().
Nillable(),
field.String("ip_address").
MaxLen(45). // 支持 IPv6
Optional().
Nillable(),
// 图片生成字段(仅 gemini-3-pro-image 等图片模型使用)
field.Int("image_count").

View File

@@ -72,6 +72,8 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms,omitempty"`
// UserAgent holds the value of the "user_agent" field.
UserAgent *string `json:"user_agent,omitempty"`
// IPAddress holds the value of the "ip_address" field.
IPAddress *string `json:"ip_address,omitempty"`
// ImageCount holds the value of the "image_count" field.
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
@@ -167,7 +169,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldImageSize:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -347,6 +349,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.UserAgent = new(string)
*_m.UserAgent = value.String
}
case usagelog.FieldIPAddress:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
} else if value.Valid {
_m.IPAddress = new(string)
*_m.IPAddress = value.String
}
case usagelog.FieldImageCount:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field image_count", values[i])
@@ -512,6 +521,11 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.IPAddress; v != nil {
builder.WriteString("ip_address=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("image_count=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageCount))
builder.WriteString(", ")

View File

@@ -64,6 +64,8 @@ const (
FieldFirstTokenMs = "first_token_ms"
// FieldUserAgent holds the string denoting the user_agent field in the database.
FieldUserAgent = "user_agent"
// FieldIPAddress holds the string denoting the ip_address field in the database.
FieldIPAddress = "ip_address"
// FieldImageCount holds the string denoting the image_count field in the database.
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
@@ -147,6 +149,7 @@ var Columns = []string{
FieldDurationMs,
FieldFirstTokenMs,
FieldUserAgent,
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldCreatedAt,
@@ -199,6 +202,8 @@ var (
DefaultStream bool
// UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
UserAgentValidator func(string) error
// IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
IPAddressValidator func(string) error
// DefaultImageCount holds the default value on creation for the "image_count" field.
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
@@ -340,6 +345,11 @@ func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
}
// ByIPAddress orders the results by the ip_address field.
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
}
// ByImageCount orders the results by the image_count field.
func ByImageCount(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageCount, opts...).ToFunc()

View File

@@ -180,6 +180,11 @@ func UserAgent(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldUserAgent, v))
}
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
func IPAddress(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// ImageCount applies equality check predicate on the "image_count" field. It's identical to ImageCountEQ.
func ImageCount(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))
@@ -1190,6 +1195,81 @@ func UserAgentContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldUserAgent, v))
}
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
func IPAddressEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldIPAddress, v))
}
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
func IPAddressNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldIPAddress, v))
}
// IPAddressIn applies the In predicate on the "ip_address" field.
func IPAddressIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldIPAddress, vs...))
}
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
func IPAddressNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldIPAddress, vs...))
}
// IPAddressGT applies the GT predicate on the "ip_address" field.
func IPAddressGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldIPAddress, v))
}
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
func IPAddressGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldIPAddress, v))
}
// IPAddressLT applies the LT predicate on the "ip_address" field.
func IPAddressLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldIPAddress, v))
}
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
func IPAddressLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldIPAddress, v))
}
// IPAddressContains applies the Contains predicate on the "ip_address" field.
func IPAddressContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldIPAddress, v))
}
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
func IPAddressHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldIPAddress, v))
}
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
func IPAddressHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldIPAddress, v))
}
// IPAddressIsNil applies the IsNil predicate on the "ip_address" field.
func IPAddressIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldIPAddress))
}
// IPAddressNotNil applies the NotNil predicate on the "ip_address" field.
func IPAddressNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldIPAddress))
}
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
func IPAddressEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldIPAddress, v))
}
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
func IPAddressContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldIPAddress, v))
}
// ImageCountEQ applies the EQ predicate on the "image_count" field.
func ImageCountEQ(v int) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageCount, v))

View File

@@ -337,6 +337,20 @@ func (_c *UsageLogCreate) SetNillableUserAgent(v *string) *UsageLogCreate {
return _c
}
// SetIPAddress sets the "ip_address" field.
func (_c *UsageLogCreate) SetIPAddress(v string) *UsageLogCreate {
_c.mutation.SetIPAddress(v)
return _c
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableIPAddress(v *string) *UsageLogCreate {
if v != nil {
_c.SetIPAddress(*v)
}
return _c
}
// SetImageCount sets the "image_count" field.
func (_c *UsageLogCreate) SetImageCount(v int) *UsageLogCreate {
_c.mutation.SetImageCount(v)
@@ -586,6 +600,11 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
}
}
if v, ok := _c.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if _, ok := _c.mutation.ImageCount(); !ok {
return &ValidationError{Name: "image_count", err: errors.New(`ent: missing required field "UsageLog.image_count"`)}
}
@@ -713,6 +732,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldUserAgent, field.TypeString, value)
_node.UserAgent = &value
}
if value, ok := _c.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
_node.IPAddress = &value
}
if value, ok := _c.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
_node.ImageCount = value
@@ -1288,6 +1311,24 @@ func (u *UsageLogUpsert) ClearUserAgent() *UsageLogUpsert {
return u
}
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsert) SetIPAddress(v string) *UsageLogUpsert {
u.Set(usagelog.FieldIPAddress, v)
return u
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateIPAddress() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldIPAddress)
return u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsert) ClearIPAddress() *UsageLogUpsert {
u.SetNull(usagelog.FieldIPAddress)
return u
}
// SetImageCount sets the "image_count" field.
func (u *UsageLogUpsert) SetImageCount(v int) *UsageLogUpsert {
u.Set(usagelog.FieldImageCount, v)
@@ -1866,6 +1907,27 @@ func (u *UsageLogUpsertOne) ClearUserAgent() *UsageLogUpsertOne {
})
}
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertOne) SetIPAddress(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertOne) ClearIPAddress() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertOne) SetImageCount(v int) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2616,6 +2678,27 @@ func (u *UsageLogUpsertBulk) ClearUserAgent() *UsageLogUpsertBulk {
})
}
// SetIPAddress sets the "ip_address" field.
func (u *UsageLogUpsertBulk) SetIPAddress(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetIPAddress(v)
})
}
// UpdateIPAddress sets the "ip_address" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateIPAddress()
})
}
// ClearIPAddress clears the value of the "ip_address" field.
func (u *UsageLogUpsertBulk) ClearIPAddress() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearIPAddress()
})
}
// SetImageCount sets the "image_count" field.
func (u *UsageLogUpsertBulk) SetImageCount(v int) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@@ -524,6 +524,26 @@ func (_u *UsageLogUpdate) ClearUserAgent() *UsageLogUpdate {
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdate) SetIPAddress(v string) *UsageLogUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableIPAddress(v *string) *UsageLogUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdate) ClearIPAddress() *UsageLogUpdate {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdate) SetImageCount(v int) *UsageLogUpdate {
_u.mutation.ResetImageCount()
@@ -669,6 +689,11 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
@@ -815,6 +840,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
}
@@ -1484,6 +1515,26 @@ func (_u *UsageLogUpdateOne) ClearUserAgent() *UsageLogUpdateOne {
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *UsageLogUpdateOne) SetIPAddress(v string) *UsageLogUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableIPAddress(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// ClearIPAddress clears the value of the "ip_address" field.
func (_u *UsageLogUpdateOne) ClearIPAddress() *UsageLogUpdateOne {
_u.mutation.ClearIPAddress()
return _u
}
// SetImageCount sets the "image_count" field.
func (_u *UsageLogUpdateOne) SetImageCount(v int) *UsageLogUpdateOne {
_u.mutation.ResetImageCount()
@@ -1642,6 +1693,11 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)}
}
}
if v, ok := _u.mutation.IPAddress(); ok {
if err := usagelog.IPAddressValidator(v); err != nil {
return &ValidationError{Name: "ip_address", err: fmt.Errorf(`ent: validator failed for field "UsageLog.ip_address": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSize(); ok {
if err := usagelog.ImageSizeValidator(v); err != nil {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
@@ -1805,6 +1861,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.UserAgentCleared() {
_spec.ClearField(usagelog.FieldUserAgent, field.TypeString)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(usagelog.FieldIPAddress, field.TypeString, value)
}
if _u.mutation.IPAddressCleared() {
_spec.ClearField(usagelog.FieldIPAddress, field.TypeString)
}
if value, ok := _u.mutation.ImageCount(); ok {
_spec.SetField(usagelog.FieldImageCount, field.TypeInt, value)
}

View File

@@ -27,16 +27,20 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler {
// CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// UpdateAPIKeyRequest represents the update API key request payload
type UpdateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// List handles listing user's API keys with pagination
@@ -110,9 +114,11 @@ func (h *APIKeyHandler) Create(c *gin.Context) {
}
svcReq := service.CreateAPIKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq)
if err != nil {
@@ -144,7 +150,10 @@ func (h *APIKeyHandler) Update(c *gin.Context) {
return
}
svcReq := service.UpdateAPIKeyRequest{}
svcReq := service.UpdateAPIKeyRequest{
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if req.Name != "" {
svcReq.Name = &req.Name
}

View File

@@ -53,16 +53,18 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
return nil
}
return &APIKey{
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
ID: k.ID,
UserID: k.UserID,
Key: k.Key,
Name: k.Name,
GroupID: k.GroupID,
Status: k.Status,
IPWhitelist: k.IPWhitelist,
IPBlacklist: k.IPBlacklist,
CreatedAt: k.CreatedAt,
UpdatedAt: k.UpdatedAt,
User: UserFromServiceShallow(k.User),
Group: GroupFromServiceShallow(k.Group),
}
}
@@ -250,11 +252,12 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
// The account parameter allows caller to control what Account info is included.
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *UsageLog {
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
if l == nil {
return nil
}
return &UsageLog{
result := &UsageLog{
ID: l.ID,
UserID: l.UserID,
APIKeyID: l.APIKeyID,
@@ -290,21 +293,26 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag
Group: GroupFromServiceShallow(l.Group),
Subscription: UserSubscriptionFromService(l.Subscription),
}
// IP 地址仅对管理员可见
if includeIPAddress {
result.IPAddress = l.IPAddress
}
return result
}
// UsageLogFromService converts a service UsageLog to DTO for regular users.
// It excludes Account details - users should not see account information.
// It excludes Account details and IP address - users should not see these.
func UsageLogFromService(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, nil)
return usageLogFromServiceBase(l, nil, false)
}
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
// It includes minimal Account info (ID, Name only).
// It includes minimal Account info (ID, Name only) and IP address.
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
if l == nil {
return nil
}
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account))
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
}
func SettingFromService(s *service.Setting) *Setting {

View File

@@ -20,14 +20,16 @@ type User struct {
}
type APIKey struct {
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
UserID int64 `json:"user_id"`
Key string `json:"key"`
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"`
IPBlacklist []string `json:"ip_blacklist"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
@@ -187,6 +189,9 @@ type UsageLog struct {
// User-Agent
UserAgent *string `json:"user_agent"`
// IP 地址(仅管理员可见)
IPAddress *string `json:"ip_address,omitempty"`
CreatedAt time.Time `json:"created_at"`
User *User `json:"user,omitempty"`

View File

@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -114,6 +115,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
@@ -273,7 +277,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -283,10 +287,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}
@@ -401,7 +406,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
}
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -411,10 +416,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/gemini"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -167,6 +168,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 获取 User-Agent
userAgent := c.Request.UserAgent()
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
// For Gemini native API, do not send Claude-style ping frames.
geminiConcurrency := NewConcurrencyHelper(h.concurrencyHelper.concurrencyService, SSEPingFormatNone, 0)
@@ -307,7 +311,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
}
// 6) record usage async
go func(result *service.ForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
@@ -317,10 +321,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -94,6 +95,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// For non-Codex CLI requests, set default instructions
userAgent := c.GetHeader("User-Agent")
// 获取客户端 IP
clientIP := ip.GetClientIP(c)
if !openai.IsCodexCLIRequest(userAgent) {
reqBody["instructions"] = openai.DefaultInstructions
// Re-serialize body
@@ -242,7 +247,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
// Async record usage
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string) {
go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua string, cip string) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
@@ -252,10 +257,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: cip,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent)
}(result, account, userAgent, clientIP)
return
}
}

View File

@@ -0,0 +1,168 @@
// Package ip 提供客户端 IP 地址提取工具。
package ip
import (
"net"
"strings"
"github.com/gin-gonic/gin"
)
// GetClientIP 从 Gin Context 中提取客户端真实 IP 地址。
// 按以下优先级检查 Header
// 1. CF-Connecting-IP (Cloudflare)
// 2. X-Real-IP (Nginx)
// 3. X-Forwarded-For (取第一个非私有 IP)
// 4. c.ClientIP() (Gin 内置方法)
func GetClientIP(c *gin.Context) string {
// 1. Cloudflare
if ip := c.GetHeader("CF-Connecting-IP"); ip != "" {
return normalizeIP(ip)
}
// 2. Nginx X-Real-IP
if ip := c.GetHeader("X-Real-IP"); ip != "" {
return normalizeIP(ip)
}
// 3. X-Forwarded-For (多个 IP 时取第一个公网 IP)
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
for _, ip := range ips {
ip = strings.TrimSpace(ip)
if ip != "" && !isPrivateIP(ip) {
return normalizeIP(ip)
}
}
// 如果都是私有 IP返回第一个
if len(ips) > 0 {
return normalizeIP(strings.TrimSpace(ips[0]))
}
}
// 4. Gin 内置方法
return normalizeIP(c.ClientIP())
}
// normalizeIP 规范化 IP 地址,去除端口号和空格。
func normalizeIP(ip string) string {
ip = strings.TrimSpace(ip)
// 移除端口号(如 "192.168.1.1:8080" -> "192.168.1.1"
if host, _, err := net.SplitHostPort(ip); err == nil {
return host
}
return ip
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
// 私有 IP 范围
privateBlocks := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
"::1/128",
"fc00::/7",
}
for _, block := range privateBlocks {
_, cidr, err := net.ParseCIDR(block)
if err != nil {
continue
}
if cidr.Contains(ip) {
return true
}
}
return false
}
// MatchesPattern 检查 IP 是否匹配指定的模式(支持单个 IP 或 CIDR
// pattern 可以是:
// - 单个 IP: "192.168.1.100"
// - CIDR 范围: "192.168.1.0/24"
func MatchesPattern(clientIP, pattern string) bool {
ip := net.ParseIP(clientIP)
if ip == nil {
return false
}
// 尝试解析为 CIDR
if strings.Contains(pattern, "/") {
_, cidr, err := net.ParseCIDR(pattern)
if err != nil {
return false
}
return cidr.Contains(ip)
}
// 作为单个 IP 处理
patternIP := net.ParseIP(pattern)
if patternIP == nil {
return false
}
return ip.Equal(patternIP)
}
// MatchesAnyPattern 检查 IP 是否匹配任意一个模式。
func MatchesAnyPattern(clientIP string, patterns []string) bool {
for _, pattern := range patterns {
if MatchesPattern(clientIP, pattern) {
return true
}
}
return false
}
// CheckIPRestriction 检查 IP 是否被 API Key 的 IP 限制允许。
// 返回值:(是否允许, 拒绝原因)
// 逻辑:
// 1. 先检查黑名单,如果在黑名单中则直接拒绝
// 2. 如果白名单不为空IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
// 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单如果设置了白名单IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
return false, "access denied"
}
return true, ""
}
// ValidateIPPattern 验证 IP 或 CIDR 格式是否有效。
func ValidateIPPattern(pattern string) bool {
if strings.Contains(pattern, "/") {
_, _, err := net.ParseCIDR(pattern)
return err == nil
}
return net.ParseIP(pattern) != nil
}
// ValidateIPPatterns 验证多个 IP 或 CIDR 格式。
// 返回无效的模式列表。
func ValidateIPPatterns(patterns []string) []string {
var invalid []string
for _, p := range patterns {
if !ValidateIPPattern(p) {
invalid = append(invalid, p)
}
}
return invalid
}

View File

@@ -26,13 +26,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
builder := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
Save(ctx)
SetNillableGroupID(key.GroupID)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
}
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.CreatedAt = created.CreatedAt
@@ -108,6 +116,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
} else {
builder.ClearIPWhitelist()
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
} else {
builder.ClearIPBlacklist()
}
affected, err := builder.Save(ctx)
if err != nil {
return err
@@ -268,14 +288,16 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
return nil
}
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -110,6 +110,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
created_at
@@ -119,7 +120,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -130,6 +131,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
var requestIDArg any
@@ -163,6 +165,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
createdAt,
@@ -1873,6 +1876,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
@@ -1905,6 +1909,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&durationMs,
&firstTokenMs,
&userAgent,
&ipAddress,
&imageCount,
&imageSize,
&createdAt,
@@ -1959,6 +1964,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if ipAddress.Valid {
log.IPAddress = &ipAddress.String
}
if imageSize.Valid {
log.ImageSize = &imageSize.String
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -71,6 +72,17 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
return
}
// 检查 IP 限制(白名单/黑名单)
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetClientIP(c)
allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist)
if !allowed {
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")

View File

@@ -3,16 +3,18 @@ package service
import "time"
type APIKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
IPWhitelist []string
IPBlacklist []string
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
}
func (k *APIKey) IsActive() bool {

View File

@@ -9,6 +9,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)
@@ -20,6 +21,7 @@ var (
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
)
const (
@@ -57,16 +59,20 @@ type APIKeyCache interface {
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
}
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
}
// APIKeyService API Key服务
@@ -186,6 +192,20 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
return nil, fmt.Errorf("get user: %w", err)
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
@@ -236,11 +256,13 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK
// 创建API Key记录
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
@@ -312,6 +334,20 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
return nil, ErrInsufficientPerms
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
@@ -344,6 +380,10 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}

View File

@@ -2247,6 +2247,7 @@ type RecordUsageInput struct {
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
@@ -2337,6 +2338,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID

View File

@@ -1093,6 +1093,7 @@ type OpenAIRecordUsageInput struct {
Account *Account
Subscription *UserSubscription
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage records usage and deducts balance
@@ -1167,6 +1168,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}

View File

@@ -39,6 +39,7 @@ type UsageLog struct {
DurationMs *int
FirstTokenMs *int
UserAgent *string
IPAddress *string
// 图片生成字段
ImageCount int

View File

@@ -0,0 +1,5 @@
-- Add IP address field to usage_logs table for request tracking (admin-only visibility)
ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS ip_address VARCHAR(45);
-- Create index for IP address queries
CREATE INDEX IF NOT EXISTS idx_usage_logs_ip_address ON usage_logs(ip_address);

View File

@@ -0,0 +1,9 @@
-- Add IP restriction fields to api_keys table
-- ip_whitelist: JSON array of allowed IPs/CIDRs (if set, only these IPs can use the key)
-- ip_blacklist: JSON array of blocked IPs/CIDRs (these IPs are always blocked)
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_whitelist JSONB DEFAULT NULL;
ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS ip_blacklist JSONB DEFAULT NULL;
COMMENT ON COLUMN api_keys.ip_whitelist IS 'JSON array of allowed IPs/CIDRs, e.g. ["192.168.1.100", "10.0.0.0/8"]';
COMMENT ON COLUMN api_keys.ip_blacklist IS 'JSON array of blocked IPs/CIDRs, e.g. ["1.2.3.4", "5.6.0.0/16"]';