diff --git a/backend/Dockerfile b/backend/Dockerfile index 770fdedf..4b5b6286 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.5-alpine +FROM golang:1.25.6-alpine WORKDIR /app @@ -15,7 +15,7 @@ RUN go mod download COPY . . # 构建应用 -RUN go build -o main cmd/server/main.go +RUN go build -o main ./cmd/server/ # 暴露端口 EXPOSE 8080 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 694d05a7..ab51540f 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -173,8 +173,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 95586017..91d71964 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -40,6 +40,12 @@ type APIKey struct { IPWhitelist []string `json:"ip_whitelist,omitempty"` // Blocked IPs/CIDRs IPBlacklist []string `json:"ip_blacklist,omitempty"` + // Quota limit in USD for this API key (0 = unlimited) + Quota float64 `json:"quota,omitempty"` + // Used quota amount in USD + QuotaUsed float64 `json:"quota_used,omitempty"` + // Expiration time for this API key (null = never expires) + ExpiresAt *time.Time `json:"expires_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. Edges APIKeyEdges `json:"edges"` @@ -97,11 +103,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { switch columns[i] { case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: values[i] = new([]byte) + case apikey.FieldQuota, apikey.FieldQuotaUsed: + values[i] = new(sql.NullFloat64) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -190,6 +198,25 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field ip_blacklist: %w", err) } } + case apikey.FieldQuota: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field quota", values[i]) + } else if value.Valid { + _m.Quota = value.Float64 + } + case apikey.FieldQuotaUsed: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field quota_used", values[i]) + } else if value.Valid { + _m.QuotaUsed = value.Float64 + } + case apikey.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = new(time.Time) + *_m.ExpiresAt = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -274,6 +301,17 @@ func (_m *APIKey) String() string { builder.WriteString(", ") builder.WriteString("ip_blacklist=") builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist)) + builder.WriteString(", ") + builder.WriteString("quota=") + builder.WriteString(fmt.Sprintf("%v", _m.Quota)) + builder.WriteString(", ") + builder.WriteString("quota_used=") + builder.WriteString(fmt.Sprintf("%v", _m.QuotaUsed)) + builder.WriteString(", ") + if v := _m.ExpiresAt; v != nil { + builder.WriteString("expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 564cddb1..ac2a6008 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -35,6 +35,12 @@ const ( FieldIPWhitelist = "ip_whitelist" // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. FieldIPBlacklist = "ip_blacklist" + // FieldQuota holds the string denoting the quota field in the database. + FieldQuota = "quota" + // FieldQuotaUsed holds the string denoting the quota_used field in the database. + FieldQuotaUsed = "quota_used" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -79,6 +85,9 @@ var Columns = []string{ FieldStatus, FieldIPWhitelist, FieldIPBlacklist, + FieldQuota, + FieldQuotaUsed, + FieldExpiresAt, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -113,6 +122,10 @@ var ( DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultQuota holds the default value on creation for the "quota" field. + DefaultQuota float64 + // DefaultQuotaUsed holds the default value on creation for the "quota_used" field. + DefaultQuotaUsed float64 ) // OrderOption defines the ordering options for the APIKey queries. @@ -163,6 +176,21 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByQuota orders the results by the quota field. +func ByQuota(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldQuota, opts...).ToFunc() +} + +// ByQuotaUsed orders the results by the quota_used field. +func ByQuotaUsed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldQuotaUsed, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + // ByUserField orders the results by user field. func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 5152867f..f54f44b7 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -95,6 +95,21 @@ func Status(v string) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldStatus, v)) } +// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ. +func Quota(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) +} + +// QuotaUsed applies equality check predicate on the "quota_used" field. It's identical to QuotaUsedEQ. +func QuotaUsed(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) @@ -490,6 +505,136 @@ func IPBlacklistNotNil() predicate.APIKey { return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist)) } +// QuotaEQ applies the EQ predicate on the "quota" field. +func QuotaEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) +} + +// QuotaNEQ applies the NEQ predicate on the "quota" field. +func QuotaNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldQuota, v)) +} + +// QuotaIn applies the In predicate on the "quota" field. +func QuotaIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldQuota, vs...)) +} + +// QuotaNotIn applies the NotIn predicate on the "quota" field. +func QuotaNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldQuota, vs...)) +} + +// QuotaGT applies the GT predicate on the "quota" field. +func QuotaGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldQuota, v)) +} + +// QuotaGTE applies the GTE predicate on the "quota" field. +func QuotaGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldQuota, v)) +} + +// QuotaLT applies the LT predicate on the "quota" field. +func QuotaLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldQuota, v)) +} + +// QuotaLTE applies the LTE predicate on the "quota" field. +func QuotaLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldQuota, v)) +} + +// QuotaUsedEQ applies the EQ predicate on the "quota_used" field. +func QuotaUsedEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v)) +} + +// QuotaUsedNEQ applies the NEQ predicate on the "quota_used" field. +func QuotaUsedNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldQuotaUsed, v)) +} + +// QuotaUsedIn applies the In predicate on the "quota_used" field. +func QuotaUsedIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldQuotaUsed, vs...)) +} + +// QuotaUsedNotIn applies the NotIn predicate on the "quota_used" field. +func QuotaUsedNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldQuotaUsed, vs...)) +} + +// QuotaUsedGT applies the GT predicate on the "quota_used" field. +func QuotaUsedGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldQuotaUsed, v)) +} + +// QuotaUsedGTE applies the GTE predicate on the "quota_used" field. +func QuotaUsedGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldQuotaUsed, v)) +} + +// QuotaUsedLT applies the LT predicate on the "quota_used" field. +func QuotaUsedLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldQuotaUsed, v)) +} + +// QuotaUsedLTE applies the LTE predicate on the "quota_used" field. +func QuotaUsedLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldQuotaUsed, v)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field. +func ExpiresAtIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldExpiresAt)) +} + +// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field. +func ExpiresAtNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index d5363be5..71540975 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -125,6 +125,48 @@ func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate { return _c } +// SetQuota sets the "quota" field. +func (_c *APIKeyCreate) SetQuota(v float64) *APIKeyCreate { + _c.mutation.SetQuota(v) + return _c +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableQuota(v *float64) *APIKeyCreate { + if v != nil { + _c.SetQuota(*v) + } + return _c +} + +// SetQuotaUsed sets the "quota_used" field. +func (_c *APIKeyCreate) SetQuotaUsed(v float64) *APIKeyCreate { + _c.mutation.SetQuotaUsed(v) + return _c +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableQuotaUsed(v *float64) *APIKeyCreate { + if v != nil { + _c.SetQuotaUsed(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *APIKeyCreate) SetExpiresAt(v time.Time) *APIKeyCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetExpiresAt(*v) + } + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -205,6 +247,14 @@ func (_c *APIKeyCreate) defaults() error { v := apikey.DefaultStatus _c.mutation.SetStatus(v) } + if _, ok := _c.mutation.Quota(); !ok { + v := apikey.DefaultQuota + _c.mutation.SetQuota(v) + } + if _, ok := _c.mutation.QuotaUsed(); !ok { + v := apikey.DefaultQuotaUsed + _c.mutation.SetQuotaUsed(v) + } return nil } @@ -243,6 +293,12 @@ func (_c *APIKeyCreate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "APIKey.status": %w`, err)} } } + if _, ok := _c.mutation.Quota(); !ok { + return &ValidationError{Name: "quota", err: errors.New(`ent: missing required field "APIKey.quota"`)} + } + if _, ok := _c.mutation.QuotaUsed(); !ok { + return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} + } if len(_c.mutation.UserIDs()) == 0 { return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} } @@ -305,6 +361,18 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) _node.IPBlacklist = value } + if value, ok := _c.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + _node.Quota = value + } + if value, ok := _c.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + _node.QuotaUsed = value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = &value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -539,6 +607,60 @@ func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert { return u } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsert) SetQuota(v float64) *APIKeyUpsert { + u.Set(apikey.FieldQuota, v) + return u +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateQuota() *APIKeyUpsert { + u.SetExcluded(apikey.FieldQuota) + return u +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsert) AddQuota(v float64) *APIKeyUpsert { + u.Add(apikey.FieldQuota, v) + return u +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsert) SetQuotaUsed(v float64) *APIKeyUpsert { + u.Set(apikey.FieldQuotaUsed, v) + return u +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateQuotaUsed() *APIKeyUpsert { + u.SetExcluded(apikey.FieldQuotaUsed) + return u +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsert) AddQuotaUsed(v float64) *APIKeyUpsert { + u.Add(apikey.FieldQuotaUsed, v) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsert) SetExpiresAt(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateExpiresAt() *APIKeyUpsert { + u.SetExcluded(apikey.FieldExpiresAt) + return u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert { + u.SetNull(apikey.FieldExpiresAt) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -738,6 +860,69 @@ func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne { }) } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsertOne) SetQuota(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuota(v) + }) +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsertOne) AddQuota(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuota(v) + }) +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateQuota() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuota() + }) +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsertOne) SetQuotaUsed(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuotaUsed(v) + }) +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsertOne) AddQuotaUsed(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuotaUsed(v) + }) +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateQuotaUsed() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuotaUsed() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsertOne) SetExpiresAt(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateExpiresAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearExpiresAt() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1103,6 +1288,69 @@ func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk { }) } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsertBulk) SetQuota(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuota(v) + }) +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsertBulk) AddQuota(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuota(v) + }) +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateQuota() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuota() + }) +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsertBulk) SetQuotaUsed(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuotaUsed(v) + }) +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsertBulk) AddQuotaUsed(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuotaUsed(v) + }) +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateQuotaUsed() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuotaUsed() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsertBulk) SetExpiresAt(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateExpiresAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearExpiresAt() + }) +} + // Exec executes the query. func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 9ae332a8..b4ff230b 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -170,6 +170,68 @@ func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate { return _u } +// SetQuota sets the "quota" field. +func (_u *APIKeyUpdate) SetQuota(v float64) *APIKeyUpdate { + _u.mutation.ResetQuota() + _u.mutation.SetQuota(v) + return _u +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableQuota(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetQuota(*v) + } + return _u +} + +// AddQuota adds value to the "quota" field. +func (_u *APIKeyUpdate) AddQuota(v float64) *APIKeyUpdate { + _u.mutation.AddQuota(v) + return _u +} + +// SetQuotaUsed sets the "quota_used" field. +func (_u *APIKeyUpdate) SetQuotaUsed(v float64) *APIKeyUpdate { + _u.mutation.ResetQuotaUsed() + _u.mutation.SetQuotaUsed(v) + return _u +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableQuotaUsed(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetQuotaUsed(*v) + } + return _u +} + +// AddQuotaUsed adds value to the "quota_used" field. +func (_u *APIKeyUpdate) AddQuotaUsed(v float64) *APIKeyUpdate { + _u.mutation.AddQuotaUsed(v) + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *APIKeyUpdate) SetExpiresAt(v time.Time) *APIKeyUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableExpiresAt(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate { + _u.mutation.ClearExpiresAt() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -350,6 +412,24 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.IPBlacklistCleared() { _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) } + if value, ok := _u.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuota(); ok { + _spec.AddField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuotaUsed(); ok { + _spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -611,6 +691,68 @@ func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne { return _u } +// SetQuota sets the "quota" field. +func (_u *APIKeyUpdateOne) SetQuota(v float64) *APIKeyUpdateOne { + _u.mutation.ResetQuota() + _u.mutation.SetQuota(v) + return _u +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableQuota(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetQuota(*v) + } + return _u +} + +// AddQuota adds value to the "quota" field. +func (_u *APIKeyUpdateOne) AddQuota(v float64) *APIKeyUpdateOne { + _u.mutation.AddQuota(v) + return _u +} + +// SetQuotaUsed sets the "quota_used" field. +func (_u *APIKeyUpdateOne) SetQuotaUsed(v float64) *APIKeyUpdateOne { + _u.mutation.ResetQuotaUsed() + _u.mutation.SetQuotaUsed(v) + return _u +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableQuotaUsed(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetQuotaUsed(*v) + } + return _u +} + +// AddQuotaUsed adds value to the "quota_used" field. +func (_u *APIKeyUpdateOne) AddQuotaUsed(v float64) *APIKeyUpdateOne { + _u.mutation.AddQuotaUsed(v) + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *APIKeyUpdateOne) SetExpiresAt(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableExpiresAt(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne { + _u.mutation.ClearExpiresAt() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -821,6 +963,24 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if _u.mutation.IPBlacklistCleared() { _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) } + if value, ok := _u.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuota(); ok { + _spec.AddField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuotaUsed(); ok { + _spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d0238545..dc91f6a5 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -20,6 +20,9 @@ var ( {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, + {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "expires_at", Type: field.TypeTime, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -31,13 +34,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[9]}, + Columns: []*schema.Column{APIKeysColumns[12]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[13]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -46,12 +49,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[13]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[9]}, + Columns: []*schema.Column{APIKeysColumns[12]}, }, { Name: "apikey_status", @@ -63,6 +66,16 @@ var ( Unique: false, Columns: []*schema.Column{APIKeysColumns[3]}, }, + { + Name: "apikey_quota_quota_used", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]}, + }, + { + Name: "apikey_expires_at", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[11]}, + }, }, } // AccountsColumns holds the columns for the "accounts" table. diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index c7812024..77d208e1 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -79,6 +79,11 @@ type APIKeyMutation struct { appendip_whitelist []string ip_blacklist *[]string appendip_blacklist []string + quota *float64 + addquota *float64 + quota_used *float64 + addquota_used *float64 + expires_at *time.Time clearedFields map[string]struct{} user *int64 cleareduser bool @@ -634,6 +639,167 @@ func (m *APIKeyMutation) ResetIPBlacklist() { delete(m.clearedFields, apikey.FieldIPBlacklist) } +// SetQuota sets the "quota" field. +func (m *APIKeyMutation) SetQuota(f float64) { + m.quota = &f + m.addquota = nil +} + +// Quota returns the value of the "quota" field in the mutation. +func (m *APIKeyMutation) Quota() (r float64, exists bool) { + v := m.quota + if v == nil { + return + } + return *v, true +} + +// OldQuota returns the old "quota" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldQuota(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQuota is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQuota requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQuota: %w", err) + } + return oldValue.Quota, nil +} + +// AddQuota adds f to the "quota" field. +func (m *APIKeyMutation) AddQuota(f float64) { + if m.addquota != nil { + *m.addquota += f + } else { + m.addquota = &f + } +} + +// AddedQuota returns the value that was added to the "quota" field in this mutation. +func (m *APIKeyMutation) AddedQuota() (r float64, exists bool) { + v := m.addquota + if v == nil { + return + } + return *v, true +} + +// ResetQuota resets all changes to the "quota" field. +func (m *APIKeyMutation) ResetQuota() { + m.quota = nil + m.addquota = nil +} + +// SetQuotaUsed sets the "quota_used" field. +func (m *APIKeyMutation) SetQuotaUsed(f float64) { + m.quota_used = &f + m.addquota_used = nil +} + +// QuotaUsed returns the value of the "quota_used" field in the mutation. +func (m *APIKeyMutation) QuotaUsed() (r float64, exists bool) { + v := m.quota_used + if v == nil { + return + } + return *v, true +} + +// OldQuotaUsed returns the old "quota_used" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldQuotaUsed(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQuotaUsed is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQuotaUsed requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQuotaUsed: %w", err) + } + return oldValue.QuotaUsed, nil +} + +// AddQuotaUsed adds f to the "quota_used" field. +func (m *APIKeyMutation) AddQuotaUsed(f float64) { + if m.addquota_used != nil { + *m.addquota_used += f + } else { + m.addquota_used = &f + } +} + +// AddedQuotaUsed returns the value that was added to the "quota_used" field in this mutation. +func (m *APIKeyMutation) AddedQuotaUsed() (r float64, exists bool) { + v := m.addquota_used + if v == nil { + return + } + return *v, true +} + +// ResetQuotaUsed resets all changes to the "quota_used" field. +func (m *APIKeyMutation) ResetQuotaUsed() { + m.quota_used = nil + m.addquota_used = nil +} + +// SetExpiresAt sets the "expires_at" field. +func (m *APIKeyMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *APIKeyMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (m *APIKeyMutation) ClearExpiresAt() { + m.expires_at = nil + m.clearedFields[apikey.FieldExpiresAt] = struct{}{} +} + +// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation. +func (m *APIKeyMutation) ExpiresAtCleared() bool { + _, ok := m.clearedFields[apikey.FieldExpiresAt] + return ok +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *APIKeyMutation) ResetExpiresAt() { + m.expires_at = nil + delete(m.clearedFields, apikey.FieldExpiresAt) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -776,7 +942,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 13) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -807,6 +973,15 @@ func (m *APIKeyMutation) Fields() []string { if m.ip_blacklist != nil { fields = append(fields, apikey.FieldIPBlacklist) } + if m.quota != nil { + fields = append(fields, apikey.FieldQuota) + } + if m.quota_used != nil { + fields = append(fields, apikey.FieldQuotaUsed) + } + if m.expires_at != nil { + fields = append(fields, apikey.FieldExpiresAt) + } return fields } @@ -835,6 +1010,12 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.IPWhitelist() case apikey.FieldIPBlacklist: return m.IPBlacklist() + case apikey.FieldQuota: + return m.Quota() + case apikey.FieldQuotaUsed: + return m.QuotaUsed() + case apikey.FieldExpiresAt: + return m.ExpiresAt() } return nil, false } @@ -864,6 +1045,12 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldIPWhitelist(ctx) case apikey.FieldIPBlacklist: return m.OldIPBlacklist(ctx) + case apikey.FieldQuota: + return m.OldQuota(ctx) + case apikey.FieldQuotaUsed: + return m.OldQuotaUsed(ctx) + case apikey.FieldExpiresAt: + return m.OldExpiresAt(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -943,6 +1130,27 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetIPBlacklist(v) return nil + case apikey.FieldQuota: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQuota(v) + return nil + case apikey.FieldQuotaUsed: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQuotaUsed(v) + return nil + case apikey.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -951,6 +1159,12 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *APIKeyMutation) AddedFields() []string { var fields []string + if m.addquota != nil { + fields = append(fields, apikey.FieldQuota) + } + if m.addquota_used != nil { + fields = append(fields, apikey.FieldQuotaUsed) + } return fields } @@ -959,6 +1173,10 @@ func (m *APIKeyMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { switch name { + case apikey.FieldQuota: + return m.AddedQuota() + case apikey.FieldQuotaUsed: + return m.AddedQuotaUsed() } return nil, false } @@ -968,6 +1186,20 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { // type. func (m *APIKeyMutation) AddField(name string, value ent.Value) error { switch name { + case apikey.FieldQuota: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddQuota(v) + return nil + case apikey.FieldQuotaUsed: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddQuotaUsed(v) + return nil } return fmt.Errorf("unknown APIKey numeric field %s", name) } @@ -988,6 +1220,9 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldIPBlacklist) { fields = append(fields, apikey.FieldIPBlacklist) } + if m.FieldCleared(apikey.FieldExpiresAt) { + fields = append(fields, apikey.FieldExpiresAt) + } return fields } @@ -1014,6 +1249,9 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldIPBlacklist: m.ClearIPBlacklist() return nil + case apikey.FieldExpiresAt: + m.ClearExpiresAt() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -1052,6 +1290,15 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldIPBlacklist: m.ResetIPBlacklist() return nil + case apikey.FieldQuota: + m.ResetQuota() + return nil + case apikey.FieldQuotaUsed: + m.ResetQuotaUsed() + return nil + case apikey.FieldExpiresAt: + m.ResetExpiresAt() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 8032dc58..f1fea8cc 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -91,6 +91,14 @@ func init() { apikey.DefaultStatus = apikeyDescStatus.Default.(string) // apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save. apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error) + // apikeyDescQuota is the schema descriptor for quota field. + apikeyDescQuota := apikeyFields[7].Descriptor() + // apikey.DefaultQuota holds the default value on creation for the quota field. + apikey.DefaultQuota = apikeyDescQuota.Default.(float64) + // apikeyDescQuotaUsed is the schema descriptor for quota_used field. + apikeyDescQuotaUsed := apikeyFields[8].Descriptor() + // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. + apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) accountMixin := schema.Account{}.Mixin() accountMixinHooks1 := accountMixin[1].Hooks() account.Hooks[0] = accountMixinHooks1[0] diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 1c2d4bd4..26d52cb0 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -5,6 +5,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/domain" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -52,6 +53,23 @@ func (APIKey) Fields() []ent.Field { field.JSON("ip_blacklist", []string{}). Optional(). Comment("Blocked IPs/CIDRs"), + + // ========== Quota fields ========== + // Quota limit in USD (0 = unlimited) + field.Float("quota"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Quota limit in USD for this API key (0 = unlimited)"), + // Used quota amount + field.Float("quota_used"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used quota amount in USD"), + // Expiration time (nil = never expires) + field.Time("expires_at"). + Optional(). + Nillable(). + Comment("Expiration time for this API key (null = never expires)"), } } @@ -77,5 +95,8 @@ func (APIKey) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("deleted_at"), + // Index for quota queries + index.Fields("quota", "quota_used"), + index.Fields("expires_at"), } } diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index b820a3fb..ea2ea963 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -290,5 +290,9 @@ func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*ser return &code, nil } +func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]service.RedeemCode, int64, float64, error) { + return s.redeems, int64(len(s.redeems)), 100.0, nil +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 9a5a691f..ac76689d 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -277,3 +277,44 @@ func (h *UserHandler) GetUserUsage(c *gin.Context) { response.Success(c, stats) } + +// GetBalanceHistory handles getting user's balance/concurrency change history +// GET /api/v1/admin/users/:id/balance-history +// Query params: +// - type: filter by record type (balance, admin_balance, concurrency, admin_concurrency, subscription) +func (h *UserHandler) GetBalanceHistory(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + page, pageSize := response.ParsePagination(c) + codeType := c.Query("type") + + codes, total, totalRecharged, err := h.adminService.GetUserBalanceHistory(c.Request.Context(), userID, page, pageSize, codeType) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // Convert to admin DTO (includes notes field for admin visibility) + out := make([]dto.AdminRedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + } + + // Custom response with total_recharged alongside pagination + pages := int((total + int64(pageSize) - 1) / int64(pageSize)) + if pages < 1 { + pages = 1 + } + response.Success(c, gin.H{ + "items": out, + "total": total, + "page": page, + "page_size": pageSize, + "pages": pages, + "total_recharged": totalRecharged, + }) +} diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 52dc6911..9717194b 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -3,6 +3,7 @@ package handler import ( "strconv" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -27,11 +28,13 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler { // CreateAPIKeyRequest represents the create API key request payload type CreateAPIKeyRequest struct { - Name string `json:"name" binding:"required"` - GroupID *int64 `json:"group_id"` // nullable - CustomKey *string `json:"custom_key"` // 可选的自定义key - IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 - IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Name string `json:"name" binding:"required"` + GroupID *int64 `json:"group_id"` // nullable + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Quota *float64 `json:"quota"` // 配额限制 (USD) + ExpiresInDays *int `json:"expires_in_days"` // 过期天数 } // UpdateAPIKeyRequest represents the update API key request payload @@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct { Status string `json:"status" binding:"omitempty,oneof=active inactive"` IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 + ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) + ResetQuota *bool `json:"reset_quota"` // 重置已用配额 } // List handles listing user's API keys with pagination @@ -114,11 +120,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) { } svcReq := service.CreateAPIKeyRequest{ - Name: req.Name, - GroupID: req.GroupID, - CustomKey: req.CustomKey, - IPWhitelist: req.IPWhitelist, - IPBlacklist: req.IPBlacklist, + Name: req.Name, + GroupID: req.GroupID, + CustomKey: req.CustomKey, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + ExpiresInDays: req.ExpiresInDays, + } + if req.Quota != nil { + svcReq.Quota = *req.Quota } key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) if err != nil { @@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) { svcReq := service.UpdateAPIKeyRequest{ IPWhitelist: req.IPWhitelist, IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + ResetQuota: req.ResetQuota, } if req.Name != "" { svcReq.Name = &req.Name @@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) { if req.Status != "" { svcReq.Status = &req.Status } + // Parse expires_at if provided + if req.ExpiresAt != nil { + if *req.ExpiresAt == "" { + // Empty string means clear expiration + svcReq.ExpiresAt = nil + svcReq.ClearExpiration = true + } else { + t, err := time.Parse(time.RFC3339, *req.ExpiresAt) + if err != nil { + response.BadRequest(c, "Invalid expires_at format: "+err.Error()) + return + } + svcReq.ExpiresAt = &t + } + } key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 4cdb99fe..4f8d1eeb 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey { Status: k.Status, IPWhitelist: k.IPWhitelist, IPBlacklist: k.IPBlacklist, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, CreatedAt: k.CreatedAt, UpdatedAt: k.UpdatedAt, User: UserFromServiceShallow(k.User), diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 1099e1f6..8e6faf02 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -32,16 +32,19 @@ type AdminUser struct { } type APIKey struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Key string `json:"key"` - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status"` - IPWhitelist []string `json:"ip_whitelist"` - IPBlacklist []string `json:"ip_blacklist"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Key string `json:"key"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` + IPBlacklist []string `json:"ip_blacklist"` + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index ab8082db..ccf06b7f 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -32,6 +32,7 @@ type GatewayHandler struct { userService *service.UserService billingCacheService *service.BillingCacheService usageService *service.UsageService + apiKeyService *service.APIKeyService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int @@ -46,6 +47,7 @@ func NewGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, usageService *service.UsageService, + apiKeyService *service.APIKeyService, cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) @@ -67,6 +69,7 @@ func NewGatewayHandler( userService: userService, billingCacheService: billingCacheService, usageService: usageService, + apiKeyService: apiKeyService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, @@ -321,13 +324,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } @@ -513,13 +517,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { Subscription: currentSubscription, UserAgent: ua, IPAddress: clientIP, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } }(result, account, userAgent, clientIP) return } - if !retryWithFallback { return } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index cfb59c04..787e3760 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -386,6 +386,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { IPAddress: ip, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 4c9dd8b9..a84679ae 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -24,6 +24,7 @@ import ( type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int } @@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + apiKeyService *service.APIKeyService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler( return &OpenAIGatewayHandler{ gatewayService: gatewayService, billingCacheService: billingCacheService, + apiKeyService: apiKeyService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, } @@ -299,13 +302,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 59f13985..c0cfd256 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetKey(key.Key). SetName(key.Name). SetStatus(key.Status). - SetNillableGroupID(key.GroupID) + SetNillableGroupID(key.GroupID). + SetQuota(key.Quota). + SetQuotaUsed(key.QuotaUsed). + SetNillableExpiresAt(key.ExpiresAt) if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldStatus, apikey.FieldIPWhitelist, apikey.FieldIPBlacklist, + apikey.FieldQuota, + apikey.FieldQuotaUsed, + apikey.FieldExpiresAt, ). WithUser(func(q *dbent.UserQuery) { q.Select( @@ -164,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). SetName(key.Name). SetStatus(key.Status). + SetQuota(key.Quota). + SetQuotaUsed(key.QuotaUsed). SetUpdatedAt(now) if key.GroupID != nil { builder.SetGroupID(*key.GroupID) @@ -171,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearGroupID() } + // Expiration time + if key.ExpiresAt != nil { + builder.SetExpiresAt(*key.ExpiresAt) + } else { + builder.ClearExpiresAt() + } + // IP 限制字段 if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -360,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } +// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + // Use raw SQL for atomic increment to avoid race conditions + // First get current value + m, err := r.activeQuery(). + Where(apikey.IDEQ(id)). + Select(apikey.FieldQuotaUsed). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return 0, service.ErrAPIKeyNotFound + } + return 0, err + } + + newValue := m.QuotaUsed + amount + + // Update with new value + affected, err := r.client.APIKey.Update(). + Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). + SetQuotaUsed(newValue). + Save(ctx) + if err != nil { + return 0, err + } + if affected == 0 { + return 0, service.ErrAPIKeyNotFound + } + + return newValue, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil @@ -375,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, GroupID: m.GroupID, + Quota: m.Quota, + QuotaUsed: m.QuotaUsed, + ExpiresAt: m.ExpiresAt, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index fb6f405e..513e929c 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") } return &proxyProbeService{ - ipInfoURL: defaultIPInfoURL, insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, @@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { } const ( - defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN" defaultProxyProbeTimeout = 30 * time.Second ) +// probeURLs 按优先级排列的探测 URL 列表 +// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选 +var probeURLs = []struct { + url string + parser string // "ip-api" or "httpbin" +}{ + {"http://ip-api.com/json/?lang=zh-CN", "ip-api"}, + {"http://httpbin.org/ip", "httpbin"}, +} + type proxyProbeService struct { - ipInfoURL string insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool @@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) } + var lastErr error + for _, probe := range probeURLs { + exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser) + if err == nil { + return exitInfo, latencyMs, nil + } + lastErr = err + } + + return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr) +} + +func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) { startTime := time.Now() - req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, 0, fmt.Errorf("failed to create request: %w", err) } @@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) + } + + switch parser { + case "ip-api": + return s.parseIPAPI(body, latencyMs) + case "httpbin": + return s.parseHTTPBin(body, latencyMs) + default: + return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser) + } +} + +func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) { var ipInfo struct { Status string `json:"status"` Message string `json:"message"` @@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s CountryCode string `json:"countryCode"` } - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) - } - if err := json.Unmarshal(body, &ipInfo); err != nil { - return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) + preview := string(body) + if len(preview) > 200 { + preview = preview[:200] + "..." + } + return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview) } if strings.ToLower(ipInfo.Status) != "success" { if ipInfo.Message == "" { @@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s CountryCode: ipInfo.CountryCode, }, latencyMs, nil } + +func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) { + // httpbin.org/ip 返回格式: {"origin": "1.2.3.4"} + var result struct { + Origin string `json:"origin"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err) + } + if result.Origin == "" { + return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response") + } + return &service.ProxyExitInfo{ + IP: result.Origin, + }, latencyMs, nil +} diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go index f1cd5721..7450653b 100644 --- a/backend/internal/repository/proxy_probe_service_test.go +++ b/backend/internal/repository/proxy_probe_service_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/require" @@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct { func (s *ProxyProbeServiceSuite) SetupTest() { s.ctx = context.Background() s.prober = &proxyProbeService{ - ipInfoURL: "http://ip-api.test/json/?lang=zh-CN", allowPrivateHosts: true, } } @@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() { require.ErrorContains(s.T(), err, "failed to create proxy client") } -func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { - seen := make(chan string, 1) +func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - seen <- r.RequestURI - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`) + // 检查是否是 ip-api 请求 + if strings.Contains(r.RequestURI, "ip-api.com") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`) + return + } + // 其他请求返回错误 + w.WriteHeader(http.StatusServiceUnavailable) })) info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) @@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() { require.Equal(s.T(), "r", info.Region) require.Equal(s.T(), "cc", info.Country) require.Equal(s.T(), "CC", info.CountryCode) - - // Verify proxy received the request - select { - case uri := <-seen: - require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy") - default: - require.Fail(s.T(), "expected proxy to receive request") - } } -func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() { +func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() { + s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // ip-api 失败 + if strings.Contains(r.RequestURI, "ip-api.com") { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + // httpbin 成功 + if strings.Contains(r.RequestURI, "httpbin.org") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + })) + + info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) + require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin") + require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency") + require.Equal(s.T(), "5.6.7.8", info.IP) +} + +func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) })) _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "status: 503") + require.ErrorContains(s.T(), err, "all probe URLs failed") } func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() { s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, "not-json") + if strings.Contains(r.RequestURI, "ip-api.com") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-json") + return + } + // httpbin 也返回无效响应 + if strings.Contains(r.RequestURI, "httpbin.org") { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, "not-json") + return + } + w.WriteHeader(http.StatusServiceUnavailable) })) _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "failed to parse response") -} - -func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() { - s.prober.ipInfoURL = "://invalid-url" - s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - _, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL) - require.Error(s.T(), err, "expected error for invalid ipInfoURL") + require.ErrorContains(s.T(), err, "all probe URLs failed") } func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() { @@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() { require.Error(s.T(), err, "expected error when proxy server is closed") } +func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() { + body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`) + info, latencyMs, err := s.prober.parseIPAPI(body, 100) + require.NoError(s.T(), err) + require.Equal(s.T(), int64(100), latencyMs) + require.Equal(s.T(), "1.2.3.4", info.IP) + require.Equal(s.T(), "Beijing", info.City) + require.Equal(s.T(), "Beijing", info.Region) + require.Equal(s.T(), "China", info.Country) + require.Equal(s.T(), "CN", info.CountryCode) +} + +func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() { + body := []byte(`{"status":"fail","message":"rate limited"}`) + _, _, err := s.prober.parseIPAPI(body, 100) + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "rate limited") +} + +func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() { + body := []byte(`{"origin": "9.8.7.6"}`) + info, latencyMs, err := s.prober.parseHTTPBin(body, 50) + require.NoError(s.T(), err) + require.Equal(s.T(), int64(50), latencyMs) + require.Equal(s.T(), "9.8.7.6", info.IP) +} + +func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() { + body := []byte(`{"origin": ""}`) + _, _, err := s.prober.parseHTTPBin(body, 50) + require.Error(s.T(), err) + require.ErrorContains(s.T(), err, "no IP found") +} + func TestProxyProbeServiceSuite(t *testing.T) { suite.Run(t, new(ProxyProbeServiceSuite)) } diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index ee8a01b5..a3a048c3 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim return redeemCodeEntitiesToService(codes), nil } +// ListByUserPaginated returns paginated balance/concurrency history for a user. +// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription"). +func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + q := r.client.RedeemCode.Query(). + Where(redeemcode.UsedByEQ(userID)) + + // Optional type filter + if codeType != "" { + q = q.Where(redeemcode.TypeEQ(codeType)) + } + + total, err := q.Count(ctx) + if err != nil { + return nil, nil, err + } + + codes, err := q. + WithGroup(). + Offset(params.Offset()). + Limit(params.Limit()). + Order(dbent.Desc(redeemcode.FieldUsedAt)). + All(ctx) + if err != nil { + return nil, nil, err + } + + return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil +} + +// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance). +func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) { + var result []struct { + Sum float64 `json:"sum"` + } + err := r.client.RedeemCode.Query(). + Where( + redeemcode.UsedByEQ(userID), + redeemcode.ValueGT(0), + redeemcode.TypeIn("balance", "admin_balance"), + ). + Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")). + Scan(ctx, &result) + if err != nil { + return 0, err + } + if len(result) == 0 { + return 0, nil + } + return result[0].Sum, nil +} + func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode { if m == nil { return nil diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 28a00fa3..44264e72 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "quota": 0, + "quota_used": 0, + "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "quota": 0, + "quota_used": 0, + "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -1151,6 +1157,14 @@ func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit return append([]service.RedeemCode(nil), codes...), nil } +func (stubRedeemCodeRepo) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubRedeemCodeRepo) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { byUser map[int64][]service.UserSubscription activeByUser map[int64][]service.UserSubscription @@ -1435,6 +1449,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index dff6ba95..2f739357 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查API key是否激活 if !apiKey.IsActive() { - AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") + // Provide more specific error message based on status + switch apiKey.Status { + case service.StatusAPIKeyQuotaExhausted: + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + case service.StatusAPIKeyExpired: + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + default: + AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") + } + return + } + + // 检查API Key是否过期(即使状态是active,也要检查时间) + if apiKey.IsExpired() { + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + + // 检查API Key配额是否耗尽 + if apiKey.IsQuotaExhausted() { + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") return } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 6f09469b..c14582bd 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { return nil, errors.New("not implemented") } +func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 920ff93f..a03f6168 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 3e0033e7..ca9d627e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -175,6 +175,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.POST("/:id/balance", h.Admin.User.UpdateBalance) users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) users.GET("/:id/usage", h.Admin.User.GetUserUsage) + users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 52a10476..c512f235 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -22,6 +22,10 @@ type AdminService interface { UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) + // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. + // codeType is optional - pass empty string to return all types. + // Also returns totalRecharged (sum of all positive balance top-ups). + GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) // Group management ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) @@ -536,6 +540,21 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, }, nil } +// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. +func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) + if err != nil { + return nil, 0, 0, err + } + // Aggregate total recharged amount (only once, regardless of type filter) + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, result.Total, totalRecharged, nil +} + // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 923d33ab..e2aa83d9 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -282,6 +282,14 @@ func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int panic("unexpected ListByUser call") } +func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + type subscriptionInvalidateCall struct { userID int64 groupID int64 diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index 7506c6db..d661b710 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -152,6 +152,14 @@ func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params p return s.listWithFiltersCodes, result, nil } +func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) { + panic("unexpected ListByUserPaginated call") +} + +func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) { + panic("unexpected SumPositiveBalanceByUser call") +} + func TestAdminService_ListAccounts_WithSearch(t *testing.T) { t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { repo := &accountRepoStubForAdminList{ diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 8c692d09..d66059dd 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -2,6 +2,14 @@ package service import "time" +// API Key status constants +const ( + StatusAPIKeyActive = "active" + StatusAPIKeyDisabled = "disabled" + StatusAPIKeyQuotaExhausted = "quota_exhausted" + StatusAPIKeyExpired = "expired" +) + type APIKey struct { ID int64 UserID int64 @@ -15,8 +23,53 @@ type APIKey struct { UpdatedAt time.Time User *User Group *Group + + // Quota fields + Quota float64 // Quota limit in USD (0 = unlimited) + QuotaUsed float64 // Used quota amount + ExpiresAt *time.Time // Expiration time (nil = never expires) } func (k *APIKey) IsActive() bool { return k.Status == StatusActive } + +// IsExpired checks if the API key has expired +func (k *APIKey) IsExpired() bool { + if k.ExpiresAt == nil { + return false + } + return time.Now().After(*k.ExpiresAt) +} + +// IsQuotaExhausted checks if the API key quota is exhausted +func (k *APIKey) IsQuotaExhausted() bool { + if k.Quota <= 0 { + return false // unlimited + } + return k.QuotaUsed >= k.Quota +} + +// GetQuotaRemaining returns remaining quota (-1 for unlimited) +func (k *APIKey) GetQuotaRemaining() float64 { + if k.Quota <= 0 { + return -1 // unlimited + } + remaining := k.Quota - k.QuotaUsed + if remaining < 0 { + return 0 + } + return remaining +} + +// GetDaysUntilExpiry returns days until expiry (-1 for never expires) +func (k *APIKey) GetDaysUntilExpiry() int { + if k.ExpiresAt == nil { + return -1 // never expires + } + duration := time.Until(*k.ExpiresAt) + if duration < 0 { + return 0 + } + return int(duration.Hours() / 24) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index b56e7cf3..d15b5817 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -1,5 +1,7 @@ package service +import "time" + // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) type APIKeyAuthSnapshot struct { APIKeyID int64 `json:"api_key_id"` @@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct { IPBlacklist []string `json:"ip_blacklist,omitempty"` User APIKeyAuthUserSnapshot `json:"user"` Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` + + // Quota fields for API Key independent quota feature + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + QuotaUsed float64 `json:"quota_used"` // Used quota amount + + // Expiration field for API Key expiration feature + ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) } // APIKeyAuthUserSnapshot 用户快照 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index d4b2347e..f5bba7d0 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, + Quota: apiKey.Quota, + QuotaUsed: apiKey.QuotaUsed, + ExpiresAt: apiKey.ExpiresAt, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, @@ -259,6 +262,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, + Quota: snapshot.Quota, + QuotaUsed: snapshot.QuotaUsed, + ExpiresAt: snapshot.ExpiresAt, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index ef1ff990..b27682f3 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -24,6 +24,10 @@ var ( ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") + // ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired") + ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") + // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") + ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") ) const ( @@ -51,6 +55,9 @@ type APIKeyRepository interface { CountByGroupID(ctx context.Context, groupID int64) (int64, error) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) + + // Quota methods + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) } // APIKeyCache defines cache operations for API key service @@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct { CustomKey *string `json:"custom_key"` // 可选的自定义key IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + + // Quota fields + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) } // UpdateAPIKeyRequest 更新API Key请求 @@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct { Status *string `json:"status"` IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) + + // Quota fields + Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited) + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) + ClearExpiration bool `json:"-"` // Clear expiration (internal use) + ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 } // APIKeyService API Key服务 @@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK Status: StatusActive, IPWhitelist: req.IPWhitelist, IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + QuotaUsed: 0, + } + + // Set expiration time if specified + if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 { + expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays) + apiKey.ExpiresAt = &expiresAt } if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { @@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } } + // Update quota fields + if req.Quota != nil { + apiKey.Quota = *req.Quota + // If quota is increased and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed { + apiKey.Status = StatusActive + } + } + if req.ResetQuota != nil && *req.ResetQuota { + apiKey.QuotaUsed = 0 + // If resetting quota and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted { + apiKey.Status = StatusActive + } + } + if req.ClearExpiration { + apiKey.ExpiresAt = nil + // If clearing expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired { + apiKey.Status = StatusActive + } + } else if req.ExpiresAt != nil { + apiKey.ExpiresAt = req.ExpiresAt + // If extending expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) { + apiKey.Status = StatusActive + } + } + // 更新 IP 限制(空数组会清空设置) apiKey.IPWhitelist = req.IPWhitelist apiKey.IPBlacklist = req.IPBlacklist @@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword } return keys, nil } + +// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) +// Returns nil if valid, error if invalid +func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { + // Check expiration + if apiKey.IsExpired() { + return ErrAPIKeyExpired + } + + // Check quota + if apiKey.IsQuotaExhausted() { + return ErrAPIKeyQuotaExhausted + } + + return nil +} + +// UpdateQuotaUsed updates the quota_used field after a request +// Also checks if quota is exhausted and updates status accordingly +func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + + // Use repository to atomically increment quota_used + newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + + // Check if quota is now exhausted and update status if needed + apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID) + if err != nil { + return nil // Don't fail the request, just log + } + + // If quota is set and now exhausted, update status + if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota { + apiKey.Status = StatusAPIKeyQuotaExhausted + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil // Don't fail the request + } + // Invalidate cache so next request sees the new status + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + return nil +} diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index c5e9cd47..1099b1d2 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([] return s.listKeysByGroupID(ctx, groupID) } +func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} + type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) setAuthKeys []string diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 092b7fce..d4d12144 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ( panic("unexpected ListKeysByGroupID call") } +func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 065f3cba..9b31a9c6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4540,13 +4540,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 +} + +// APIKeyQuotaUpdater defines the interface for updating API Key quota +type APIKeyQuotaUpdater interface { + UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -4686,6 +4692,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } } + // 更新 API Key 配额(如果设置了配额限制) + if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Update API key quota failed: %v", err) + } + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -4703,6 +4716,7 @@ type RecordUsageLongContextInput struct { IPAddress string // 请求的客户端 IP 地址 LongContextThreshold int // 长上下文阈值(如 200000) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + APIKeyService *APIKeyService // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -4839,6 +4853,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + // API Key 独立配额扣费 + if input.APIKeyService != nil && apiKey.Quota > 0 { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Add API key quota used failed: %v", err) + } + } } } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6d93e92d..aa9c00e0 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1681,13 +1681,14 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance @@ -1799,6 +1800,13 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } } + // Update API key quota if applicable (only for balance mode with quota set) + if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Update API key quota failed: %v", err) + } + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index adcafb3f..ad277ca0 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -49,6 +49,11 @@ type RedeemCodeRepository interface { List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) + // ListByUserPaginated returns paginated balance/concurrency history for a specific user. + // codeType filter is optional - pass empty string to return all types. + ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) + // SumPositiveBalanceByUser returns the total recharged amount (sum of positive balance values) for a user. + SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) } // GenerateCodesRequest 生成兑换码请求 diff --git a/backend/migrations/045_add_api_key_quota.sql b/backend/migrations/045_add_api_key_quota.sql new file mode 100644 index 00000000..b3c42d2c --- /dev/null +++ b/backend/migrations/045_add_api_key_quota.sql @@ -0,0 +1,20 @@ +-- Migration: Add quota fields to api_keys table +-- This migration adds independent quota and expiration support for API keys + +-- Add quota limit field (0 = unlimited) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS quota DECIMAL(20, 8) NOT NULL DEFAULT 0; + +-- Add used quota amount field +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS quota_used DECIMAL(20, 8) NOT NULL DEFAULT 0; + +-- Add expiration time field (NULL = never expires) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ; + +-- Add indexes for efficient quota queries +CREATE INDEX IF NOT EXISTS idx_api_keys_quota_quota_used ON api_keys(quota, quota_used) WHERE deleted_at IS NULL; +CREATE INDEX IF NOT EXISTS idx_api_keys_expires_at ON api_keys(expires_at) WHERE deleted_at IS NULL; + +-- Comment on columns for documentation +COMMENT ON COLUMN api_keys.quota IS 'Quota limit in USD for this API key (0 = unlimited)'; +COMMENT ON COLUMN api_keys.quota_used IS 'Used quota amount in USD'; +COMMENT ON COLUMN api_keys.expires_at IS 'Expiration time for this API key (null = never expires)'; diff --git a/backend/tools.go b/backend/tools.go deleted file mode 100644 index f06d2c78..00000000 --- a/backend/tools.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build tools -// +build tools - -package tools - -import ( - _ "entgo.io/ent/cmd/ent" - _ "github.com/google/wire/cmd/wire" -) diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index a88b02c6..9a8a4195 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -62,3 +62,6 @@ export { } export default adminAPI + +// Re-export types used by components +export type { BalanceHistoryItem } from './users' diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 734e3ac7..287aef96 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -174,6 +174,53 @@ export async function getUserUsageStats( return data } +/** + * Balance history item returned from the API + */ +export interface BalanceHistoryItem { + id: number + code: string + type: string + value: number + status: string + used_by: number | null + used_at: string | null + created_at: string + group_id: number | null + validity_days: number + notes: string + user?: { id: number; email: string } | null + group?: { id: number; name: string } | null +} + +// Balance history response extends pagination with total_recharged summary +export interface BalanceHistoryResponse extends PaginatedResponse { + total_recharged: number +} + +/** + * Get user's balance/concurrency change history + * @param id - User ID + * @param page - Page number + * @param pageSize - Items per page + * @param type - Optional type filter (balance, admin_balance, concurrency, admin_concurrency, subscription) + * @returns Paginated balance history with total_recharged + */ +export async function getUserBalanceHistory( + id: number, + page: number = 1, + pageSize: number = 20, + type?: string +): Promise { + const params: Record = { page, page_size: pageSize } + if (type) params.type = type + const { data } = await apiClient.get( + `/admin/users/${id}/balance-history`, + { params } + ) + return data +} + export const usersAPI = { list, getById, @@ -184,7 +231,8 @@ export const usersAPI = { updateConcurrency, toggleStatus, getUserApiKeys, - getUserUsageStats + getUserUsageStats, + getUserBalanceHistory } export default usersAPI diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index cdae1359..c5943789 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -44,6 +44,8 @@ export async function getById(id: number): Promise { * @param customKey - Optional custom key value * @param ipWhitelist - Optional IP whitelist * @param ipBlacklist - Optional IP blacklist + * @param quota - Optional quota limit in USD (0 = unlimited) + * @param expiresInDays - Optional days until expiry (undefined = never expires) * @returns Created API key */ export async function create( @@ -51,7 +53,9 @@ export async function create( groupId?: number | null, customKey?: string, ipWhitelist?: string[], - ipBlacklist?: string[] + ipBlacklist?: string[], + quota?: number, + expiresInDays?: number ): Promise { const payload: CreateApiKeyRequest = { name } if (groupId !== undefined) { @@ -66,6 +70,12 @@ export async function create( if (ipBlacklist && ipBlacklist.length > 0) { payload.ip_blacklist = ipBlacklist } + if (quota !== undefined && quota > 0) { + payload.quota = quota + } + if (expiresInDays !== undefined && expiresInDays > 0) { + payload.expires_in_days = expiresInDays + } const { data } = await apiClient.post('/keys', payload) return data diff --git a/frontend/src/components/admin/user/UserBalanceHistoryModal.vue b/frontend/src/components/admin/user/UserBalanceHistoryModal.vue new file mode 100644 index 00000000..e7dfdb7d --- /dev/null +++ b/frontend/src/components/admin/user/UserBalanceHistoryModal.vue @@ -0,0 +1,320 @@ + + + diff --git a/frontend/src/components/common/BaseDialog.vue b/frontend/src/components/common/BaseDialog.vue index 3d38b568..93e4ba36 100644 --- a/frontend/src/components/common/BaseDialog.vue +++ b/frontend/src/components/common/BaseDialog.vue @@ -4,6 +4,7 @@ + + +
+ + + +
+
+
+ $ + +
+

{{ t('keys.quotaAmountHint') }}

+
+ + +
+ +
+
+ + ${{ selectedKey.quota_used?.toFixed(4) || '0.0000' }} + + / + + ${{ selectedKey.quota?.toFixed(2) || '0.00' }} + +
+ +
+
+
+
+ + +
+
+ + +
+ +
+ +
+ + +
+ + +
+ + +

{{ t('keys.expirationDateHint') }}

+
+ + +
+ {{ t('keys.currentExpiration') }}: + + {{ formatDateTime(selectedKey.expires_at) }} + +
+
+