diff --git a/backend/Dockerfile b/backend/Dockerfile index 770fdedf..4b5b6286 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.5-alpine +FROM golang:1.25.6-alpine WORKDIR /app @@ -15,7 +15,7 @@ RUN go mod download COPY . . # 构建应用 -RUN go build -o main cmd/server/main.go +RUN go build -o main ./cmd/server/ # 暴露端口 EXPOSE 8080 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 6422ea20..d30fd955 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -174,8 +174,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig) soraDirectClient := service.NewSoraDirectClient(configConfig, httpUpstream, openAITokenProvider) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 95586017..91d71964 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -40,6 +40,12 @@ type APIKey struct { IPWhitelist []string `json:"ip_whitelist,omitempty"` // Blocked IPs/CIDRs IPBlacklist []string `json:"ip_blacklist,omitempty"` + // Quota limit in USD for this API key (0 = unlimited) + Quota float64 `json:"quota,omitempty"` + // Used quota amount in USD + QuotaUsed float64 `json:"quota_used,omitempty"` + // Expiration time for this API key (null = never expires) + ExpiresAt *time.Time `json:"expires_at,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the APIKeyQuery when eager-loading is set. Edges APIKeyEdges `json:"edges"` @@ -97,11 +103,13 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { switch columns[i] { case apikey.FieldIPWhitelist, apikey.FieldIPBlacklist: values[i] = new([]byte) + case apikey.FieldQuota, apikey.FieldQuotaUsed: + values[i] = new(sql.NullFloat64) case apikey.FieldID, apikey.FieldUserID, apikey.FieldGroupID: values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -190,6 +198,25 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field ip_blacklist: %w", err) } } + case apikey.FieldQuota: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field quota", values[i]) + } else if value.Valid { + _m.Quota = value.Float64 + } + case apikey.FieldQuotaUsed: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field quota_used", values[i]) + } else if value.Valid { + _m.QuotaUsed = value.Float64 + } + case apikey.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = new(time.Time) + *_m.ExpiresAt = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -274,6 +301,17 @@ func (_m *APIKey) String() string { builder.WriteString(", ") builder.WriteString("ip_blacklist=") builder.WriteString(fmt.Sprintf("%v", _m.IPBlacklist)) + builder.WriteString(", ") + builder.WriteString("quota=") + builder.WriteString(fmt.Sprintf("%v", _m.Quota)) + builder.WriteString(", ") + builder.WriteString("quota_used=") + builder.WriteString(fmt.Sprintf("%v", _m.QuotaUsed)) + builder.WriteString(", ") + if v := _m.ExpiresAt; v != nil { + builder.WriteString("expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index 564cddb1..ac2a6008 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -35,6 +35,12 @@ const ( FieldIPWhitelist = "ip_whitelist" // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. FieldIPBlacklist = "ip_blacklist" + // FieldQuota holds the string denoting the quota field in the database. + FieldQuota = "quota" + // FieldQuotaUsed holds the string denoting the quota_used field in the database. + FieldQuotaUsed = "quota_used" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" // EdgeUser holds the string denoting the user edge name in mutations. EdgeUser = "user" // EdgeGroup holds the string denoting the group edge name in mutations. @@ -79,6 +85,9 @@ var Columns = []string{ FieldStatus, FieldIPWhitelist, FieldIPBlacklist, + FieldQuota, + FieldQuotaUsed, + FieldExpiresAt, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -113,6 +122,10 @@ var ( DefaultStatus string // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultQuota holds the default value on creation for the "quota" field. + DefaultQuota float64 + // DefaultQuotaUsed holds the default value on creation for the "quota_used" field. + DefaultQuotaUsed float64 ) // OrderOption defines the ordering options for the APIKey queries. @@ -163,6 +176,21 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByQuota orders the results by the quota field. +func ByQuota(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldQuota, opts...).ToFunc() +} + +// ByQuotaUsed orders the results by the quota_used field. +func ByQuotaUsed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldQuotaUsed, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + // ByUserField orders the results by user field. func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index 5152867f..f54f44b7 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -95,6 +95,21 @@ func Status(v string) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldStatus, v)) } +// Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ. +func Quota(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) +} + +// QuotaUsed applies equality check predicate on the "quota_used" field. It's identical to QuotaUsedEQ. +func QuotaUsed(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldCreatedAt, v)) @@ -490,6 +505,136 @@ func IPBlacklistNotNil() predicate.APIKey { return predicate.APIKey(sql.FieldNotNull(FieldIPBlacklist)) } +// QuotaEQ applies the EQ predicate on the "quota" field. +func QuotaEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) +} + +// QuotaNEQ applies the NEQ predicate on the "quota" field. +func QuotaNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldQuota, v)) +} + +// QuotaIn applies the In predicate on the "quota" field. +func QuotaIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldQuota, vs...)) +} + +// QuotaNotIn applies the NotIn predicate on the "quota" field. +func QuotaNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldQuota, vs...)) +} + +// QuotaGT applies the GT predicate on the "quota" field. +func QuotaGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldQuota, v)) +} + +// QuotaGTE applies the GTE predicate on the "quota" field. +func QuotaGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldQuota, v)) +} + +// QuotaLT applies the LT predicate on the "quota" field. +func QuotaLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldQuota, v)) +} + +// QuotaLTE applies the LTE predicate on the "quota" field. +func QuotaLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldQuota, v)) +} + +// QuotaUsedEQ applies the EQ predicate on the "quota_used" field. +func QuotaUsedEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldQuotaUsed, v)) +} + +// QuotaUsedNEQ applies the NEQ predicate on the "quota_used" field. +func QuotaUsedNEQ(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldQuotaUsed, v)) +} + +// QuotaUsedIn applies the In predicate on the "quota_used" field. +func QuotaUsedIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldQuotaUsed, vs...)) +} + +// QuotaUsedNotIn applies the NotIn predicate on the "quota_used" field. +func QuotaUsedNotIn(vs ...float64) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldQuotaUsed, vs...)) +} + +// QuotaUsedGT applies the GT predicate on the "quota_used" field. +func QuotaUsedGT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldQuotaUsed, v)) +} + +// QuotaUsedGTE applies the GTE predicate on the "quota_used" field. +func QuotaUsedGTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldQuotaUsed, v)) +} + +// QuotaUsedLT applies the LT predicate on the "quota_used" field. +func QuotaUsedLT(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldQuotaUsed, v)) +} + +// QuotaUsedLTE applies the LTE predicate on the "quota_used" field. +func QuotaUsedLTE(v float64) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldQuotaUsed, v)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field. +func ExpiresAtIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldExpiresAt)) +} + +// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field. +func ExpiresAtNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldExpiresAt)) +} + // HasUser applies the HasEdge predicate on the "user" edge. func HasUser() predicate.APIKey { return predicate.APIKey(func(s *sql.Selector) { diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index d5363be5..71540975 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -125,6 +125,48 @@ func (_c *APIKeyCreate) SetIPBlacklist(v []string) *APIKeyCreate { return _c } +// SetQuota sets the "quota" field. +func (_c *APIKeyCreate) SetQuota(v float64) *APIKeyCreate { + _c.mutation.SetQuota(v) + return _c +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableQuota(v *float64) *APIKeyCreate { + if v != nil { + _c.SetQuota(*v) + } + return _c +} + +// SetQuotaUsed sets the "quota_used" field. +func (_c *APIKeyCreate) SetQuotaUsed(v float64) *APIKeyCreate { + _c.mutation.SetQuotaUsed(v) + return _c +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableQuotaUsed(v *float64) *APIKeyCreate { + if v != nil { + _c.SetQuotaUsed(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *APIKeyCreate) SetExpiresAt(v time.Time) *APIKeyCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableExpiresAt(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetExpiresAt(*v) + } + return _c +} + // SetUser sets the "user" edge to the User entity. func (_c *APIKeyCreate) SetUser(v *User) *APIKeyCreate { return _c.SetUserID(v.ID) @@ -205,6 +247,14 @@ func (_c *APIKeyCreate) defaults() error { v := apikey.DefaultStatus _c.mutation.SetStatus(v) } + if _, ok := _c.mutation.Quota(); !ok { + v := apikey.DefaultQuota + _c.mutation.SetQuota(v) + } + if _, ok := _c.mutation.QuotaUsed(); !ok { + v := apikey.DefaultQuotaUsed + _c.mutation.SetQuotaUsed(v) + } return nil } @@ -243,6 +293,12 @@ func (_c *APIKeyCreate) check() error { return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "APIKey.status": %w`, err)} } } + if _, ok := _c.mutation.Quota(); !ok { + return &ValidationError{Name: "quota", err: errors.New(`ent: missing required field "APIKey.quota"`)} + } + if _, ok := _c.mutation.QuotaUsed(); !ok { + return &ValidationError{Name: "quota_used", err: errors.New(`ent: missing required field "APIKey.quota_used"`)} + } if len(_c.mutation.UserIDs()) == 0 { return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "APIKey.user"`)} } @@ -305,6 +361,18 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldIPBlacklist, field.TypeJSON, value) _node.IPBlacklist = value } + if value, ok := _c.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + _node.Quota = value + } + if value, ok := _c.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + _node.QuotaUsed = value + } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = &value + } if nodes := _c.mutation.UserIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -539,6 +607,60 @@ func (u *APIKeyUpsert) ClearIPBlacklist() *APIKeyUpsert { return u } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsert) SetQuota(v float64) *APIKeyUpsert { + u.Set(apikey.FieldQuota, v) + return u +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateQuota() *APIKeyUpsert { + u.SetExcluded(apikey.FieldQuota) + return u +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsert) AddQuota(v float64) *APIKeyUpsert { + u.Add(apikey.FieldQuota, v) + return u +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsert) SetQuotaUsed(v float64) *APIKeyUpsert { + u.Set(apikey.FieldQuotaUsed, v) + return u +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateQuotaUsed() *APIKeyUpsert { + u.SetExcluded(apikey.FieldQuotaUsed) + return u +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsert) AddQuotaUsed(v float64) *APIKeyUpsert { + u.Add(apikey.FieldQuotaUsed, v) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsert) SetExpiresAt(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateExpiresAt() *APIKeyUpsert { + u.SetExcluded(apikey.FieldExpiresAt) + return u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsert) ClearExpiresAt() *APIKeyUpsert { + u.SetNull(apikey.FieldExpiresAt) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -738,6 +860,69 @@ func (u *APIKeyUpsertOne) ClearIPBlacklist() *APIKeyUpsertOne { }) } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsertOne) SetQuota(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuota(v) + }) +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsertOne) AddQuota(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuota(v) + }) +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateQuota() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuota() + }) +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsertOne) SetQuotaUsed(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuotaUsed(v) + }) +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsertOne) AddQuotaUsed(v float64) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuotaUsed(v) + }) +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateQuotaUsed() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuotaUsed() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsertOne) SetExpiresAt(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateExpiresAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsertOne) ClearExpiresAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearExpiresAt() + }) +} + // Exec executes the query. func (u *APIKeyUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1103,6 +1288,69 @@ func (u *APIKeyUpsertBulk) ClearIPBlacklist() *APIKeyUpsertBulk { }) } +// SetQuota sets the "quota" field. +func (u *APIKeyUpsertBulk) SetQuota(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuota(v) + }) +} + +// AddQuota adds v to the "quota" field. +func (u *APIKeyUpsertBulk) AddQuota(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuota(v) + }) +} + +// UpdateQuota sets the "quota" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateQuota() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuota() + }) +} + +// SetQuotaUsed sets the "quota_used" field. +func (u *APIKeyUpsertBulk) SetQuotaUsed(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetQuotaUsed(v) + }) +} + +// AddQuotaUsed adds v to the "quota_used" field. +func (u *APIKeyUpsertBulk) AddQuotaUsed(v float64) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.AddQuotaUsed(v) + }) +} + +// UpdateQuotaUsed sets the "quota_used" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateQuotaUsed() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateQuotaUsed() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *APIKeyUpsertBulk) SetExpiresAt(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateExpiresAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *APIKeyUpsertBulk) ClearExpiresAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearExpiresAt() + }) +} + // Exec executes the query. func (u *APIKeyUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index 9ae332a8..b4ff230b 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -170,6 +170,68 @@ func (_u *APIKeyUpdate) ClearIPBlacklist() *APIKeyUpdate { return _u } +// SetQuota sets the "quota" field. +func (_u *APIKeyUpdate) SetQuota(v float64) *APIKeyUpdate { + _u.mutation.ResetQuota() + _u.mutation.SetQuota(v) + return _u +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableQuota(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetQuota(*v) + } + return _u +} + +// AddQuota adds value to the "quota" field. +func (_u *APIKeyUpdate) AddQuota(v float64) *APIKeyUpdate { + _u.mutation.AddQuota(v) + return _u +} + +// SetQuotaUsed sets the "quota_used" field. +func (_u *APIKeyUpdate) SetQuotaUsed(v float64) *APIKeyUpdate { + _u.mutation.ResetQuotaUsed() + _u.mutation.SetQuotaUsed(v) + return _u +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableQuotaUsed(v *float64) *APIKeyUpdate { + if v != nil { + _u.SetQuotaUsed(*v) + } + return _u +} + +// AddQuotaUsed adds value to the "quota_used" field. +func (_u *APIKeyUpdate) AddQuotaUsed(v float64) *APIKeyUpdate { + _u.mutation.AddQuotaUsed(v) + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *APIKeyUpdate) SetExpiresAt(v time.Time) *APIKeyUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableExpiresAt(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *APIKeyUpdate) ClearExpiresAt() *APIKeyUpdate { + _u.mutation.ClearExpiresAt() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdate) SetUser(v *User) *APIKeyUpdate { return _u.SetUserID(v.ID) @@ -350,6 +412,24 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.IPBlacklistCleared() { _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) } + if value, ok := _u.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuota(); ok { + _spec.AddField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuotaUsed(); ok { + _spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -611,6 +691,68 @@ func (_u *APIKeyUpdateOne) ClearIPBlacklist() *APIKeyUpdateOne { return _u } +// SetQuota sets the "quota" field. +func (_u *APIKeyUpdateOne) SetQuota(v float64) *APIKeyUpdateOne { + _u.mutation.ResetQuota() + _u.mutation.SetQuota(v) + return _u +} + +// SetNillableQuota sets the "quota" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableQuota(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetQuota(*v) + } + return _u +} + +// AddQuota adds value to the "quota" field. +func (_u *APIKeyUpdateOne) AddQuota(v float64) *APIKeyUpdateOne { + _u.mutation.AddQuota(v) + return _u +} + +// SetQuotaUsed sets the "quota_used" field. +func (_u *APIKeyUpdateOne) SetQuotaUsed(v float64) *APIKeyUpdateOne { + _u.mutation.ResetQuotaUsed() + _u.mutation.SetQuotaUsed(v) + return _u +} + +// SetNillableQuotaUsed sets the "quota_used" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableQuotaUsed(v *float64) *APIKeyUpdateOne { + if v != nil { + _u.SetQuotaUsed(*v) + } + return _u +} + +// AddQuotaUsed adds value to the "quota_used" field. +func (_u *APIKeyUpdateOne) AddQuotaUsed(v float64) *APIKeyUpdateOne { + _u.mutation.AddQuotaUsed(v) + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *APIKeyUpdateOne) SetExpiresAt(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableExpiresAt(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *APIKeyUpdateOne) ClearExpiresAt() *APIKeyUpdateOne { + _u.mutation.ClearExpiresAt() + return _u +} + // SetUser sets the "user" edge to the User entity. func (_u *APIKeyUpdateOne) SetUser(v *User) *APIKeyUpdateOne { return _u.SetUserID(v.ID) @@ -821,6 +963,24 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if _u.mutation.IPBlacklistCleared() { _spec.ClearField(apikey.FieldIPBlacklist, field.TypeJSON) } + if value, ok := _u.mutation.Quota(); ok { + _spec.SetField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuota(); ok { + _spec.AddField(apikey.FieldQuota, field.TypeFloat64, value) + } + if value, ok := _u.mutation.QuotaUsed(); ok { + _spec.SetField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedQuotaUsed(); ok { + _spec.AddField(apikey.FieldQuotaUsed, field.TypeFloat64, value) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(apikey.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(apikey.FieldExpiresAt, field.TypeTime) + } if _u.mutation.UserCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, diff --git a/backend/ent/group.go b/backend/ent/group.go index 0a32543b..8bfdca42 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -64,10 +64,16 @@ type Group struct { ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + // 无效请求兜底使用的分组 ID + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // 模型路由配置:模型模式 -> 优先账号ID列表 ModelRouting map[string][]int64 `json:"model_routing,omitempty"` // 是否启用模型路由配置 ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"` + // 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台) + McpXMLInject bool `json:"mcp_xml_inject,omitempty"` + // 支持的模型系列:claude, gemini_text, gemini_image + SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -174,13 +180,13 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldModelRouting: + case group.FieldModelRouting, group.FieldSupportedModelScopes: values[i] = new([]byte) - case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -358,6 +364,13 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.FallbackGroupID = new(int64) *_m.FallbackGroupID = value.Int64 } + case group.FieldFallbackGroupIDOnInvalidRequest: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field fallback_group_id_on_invalid_request", values[i]) + } else if value.Valid { + _m.FallbackGroupIDOnInvalidRequest = new(int64) + *_m.FallbackGroupIDOnInvalidRequest = value.Int64 + } case group.FieldModelRouting: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field model_routing", values[i]) @@ -372,6 +385,20 @@ func (_m *Group) assignValues(columns []string, values []any) error { } else if value.Valid { _m.ModelRoutingEnabled = value.Bool } + case group.FieldMcpXMLInject: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field mcp_xml_inject", values[i]) + } else if value.Valid { + _m.McpXMLInject = value.Bool + } + case group.FieldSupportedModelScopes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field supported_model_scopes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.SupportedModelScopes); err != nil { + return fmt.Errorf("unmarshal field supported_model_scopes: %w", err) + } + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -543,11 +570,22 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.FallbackGroupIDOnInvalidRequest; v != nil { + builder.WriteString("fallback_group_id_on_invalid_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("model_routing=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting)) builder.WriteString(", ") builder.WriteString("model_routing_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled)) + builder.WriteString(", ") + builder.WriteString("mcp_xml_inject=") + builder.WriteString(fmt.Sprintf("%v", _m.McpXMLInject)) + builder.WriteString(", ") + builder.WriteString("supported_model_scopes=") + builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 7470dd82..7bafc615 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -61,10 +61,16 @@ const ( FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. FieldFallbackGroupID = "fallback_group_id" + // FieldFallbackGroupIDOnInvalidRequest holds the string denoting the fallback_group_id_on_invalid_request field in the database. + FieldFallbackGroupIDOnInvalidRequest = "fallback_group_id_on_invalid_request" // FieldModelRouting holds the string denoting the model_routing field in the database. FieldModelRouting = "model_routing" // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database. FieldModelRoutingEnabled = "model_routing_enabled" + // FieldMcpXMLInject holds the string denoting the mcp_xml_inject field in the database. + FieldMcpXMLInject = "mcp_xml_inject" + // FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database. + FieldSupportedModelScopes = "supported_model_scopes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -163,8 +169,11 @@ var Columns = []string{ FieldSoraVideoPricePerRequestHd, FieldClaudeCodeOnly, FieldFallbackGroupID, + FieldFallbackGroupIDOnInvalidRequest, FieldModelRouting, FieldModelRoutingEnabled, + FieldMcpXMLInject, + FieldSupportedModelScopes, } var ( @@ -224,6 +233,10 @@ var ( DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. DefaultModelRoutingEnabled bool + // DefaultMcpXMLInject holds the default value on creation for the "mcp_xml_inject" field. + DefaultMcpXMLInject bool + // DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field. + DefaultSupportedModelScopes []string ) // OrderOption defines the ordering options for the Group queries. @@ -349,11 +362,21 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() } +// ByFallbackGroupIDOnInvalidRequest orders the results by the fallback_group_id_on_invalid_request field. +func ByFallbackGroupIDOnInvalidRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFallbackGroupIDOnInvalidRequest, opts...).ToFunc() +} + // ByModelRoutingEnabled orders the results by the model_routing_enabled field. func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc() } +// ByMcpXMLInject orders the results by the mcp_xml_inject field. +func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 3f8f4c04..fb30fe86 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -170,11 +170,21 @@ func FallbackGroupID(v int64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) } +// FallbackGroupIDOnInvalidRequest applies equality check predicate on the "fallback_group_id_on_invalid_request" field. It's identical to FallbackGroupIDOnInvalidRequestEQ. +func FallbackGroupIDOnInvalidRequest(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + // ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ. func ModelRoutingEnabled(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v)) } +// McpXMLInject applies equality check predicate on the "mcp_xml_inject" field. It's identical to McpXMLInjectEQ. +func McpXMLInject(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1290,6 +1300,56 @@ func FallbackGroupIDNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) } +// FallbackGroupIDOnInvalidRequestEQ applies the EQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestNEQ applies the NEQ predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIn applies the In predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestNotIn applies the NotIn predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldFallbackGroupIDOnInvalidRequest, vs...)) +} + +// FallbackGroupIDOnInvalidRequestGT applies the GT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestGTE applies the GTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLT applies the LT predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestLTE applies the LTE predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldFallbackGroupIDOnInvalidRequest, v)) +} + +// FallbackGroupIDOnInvalidRequestIsNil applies the IsNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldFallbackGroupIDOnInvalidRequest)) +} + +// FallbackGroupIDOnInvalidRequestNotNil applies the NotNil predicate on the "fallback_group_id_on_invalid_request" field. +func FallbackGroupIDOnInvalidRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldFallbackGroupIDOnInvalidRequest)) +} + // ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field. func ModelRoutingIsNil() predicate.Group { return predicate.Group(sql.FieldIsNull(FieldModelRouting)) @@ -1310,6 +1370,16 @@ func ModelRoutingEnabledNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v)) } +// McpXMLInjectEQ applies the EQ predicate on the "mcp_xml_inject" field. +func McpXMLInjectEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) +} + +// McpXMLInjectNEQ applies the NEQ predicate on the "mcp_xml_inject" field. +func McpXMLInjectNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index ac5cb4d5..2ce0f730 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -342,6 +342,20 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { return _c } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_c *GroupCreate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupCreate { + _c.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _c +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupCreate { + if v != nil { + _c.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _c +} + // SetModelRouting sets the "model_routing" field. func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate { _c.mutation.SetModelRouting(v) @@ -362,6 +376,26 @@ func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate { return _c } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_c *GroupCreate) SetMcpXMLInject(v bool) *GroupCreate { + _c.mutation.SetMcpXMLInject(v) + return _c +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_c *GroupCreate) SetNillableMcpXMLInject(v *bool) *GroupCreate { + if v != nil { + _c.SetMcpXMLInject(*v) + } + return _c +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate { + _c.mutation.SetSupportedModelScopes(v) + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -535,6 +569,14 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultModelRoutingEnabled _c.mutation.SetModelRoutingEnabled(v) } + if _, ok := _c.mutation.McpXMLInject(); !ok { + v := group.DefaultMcpXMLInject + _c.mutation.SetMcpXMLInject(v) + } + if _, ok := _c.mutation.SupportedModelScopes(); !ok { + v := group.DefaultSupportedModelScopes + _c.mutation.SetSupportedModelScopes(v) + } return nil } @@ -593,6 +635,12 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.ModelRoutingEnabled(); !ok { return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)} } + if _, ok := _c.mutation.McpXMLInject(); !ok { + return &ValidationError{Name: "mcp_xml_inject", err: errors.New(`ent: missing required field "Group.mcp_xml_inject"`)} + } + if _, ok := _c.mutation.SupportedModelScopes(); !ok { + return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)} + } return nil } @@ -712,6 +760,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) _node.FallbackGroupID = &value } + if value, ok := _c.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + _node.FallbackGroupIDOnInvalidRequest = &value + } if value, ok := _c.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) _node.ModelRouting = value @@ -720,6 +772,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) _node.ModelRoutingEnabled = value } + if value, ok := _c.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + _node.McpXMLInject = value + } + if value, ok := _c.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + _node.SupportedModelScopes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1296,6 +1356,30 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { return u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Set(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetExcluded(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsert { + u.Add(group.FieldFallbackGroupIDOnInvalidRequest, v) + return u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsert) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsert { + u.SetNull(group.FieldFallbackGroupIDOnInvalidRequest) + return u +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert { u.Set(group.FieldModelRouting, v) @@ -1326,6 +1410,30 @@ func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert { return u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsert) SetMcpXMLInject(v bool) *GroupUpsert { + u.Set(group.FieldMcpXMLInject, v) + return u +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsert) UpdateMcpXMLInject() *GroupUpsert { + u.SetExcluded(group.FieldMcpXMLInject) + return u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsert) SetSupportedModelScopes(v []string) *GroupUpsert { + u.Set(group.FieldSupportedModelScopes, v) + return u +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert { + u.SetExcluded(group.FieldSupportedModelScopes) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1861,6 +1969,34 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -1896,6 +2032,34 @@ func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne { }) } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsertOne) SetMcpXMLInject(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetMcpXMLInject(v) + }) +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateMcpXMLInject() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateMcpXMLInject() + }) +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsertOne) SetSupportedModelScopes(v []string) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSupportedModelScopes(v) + }) +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSupportedModelScopes() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2597,6 +2761,34 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { }) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupIDOnInvalidRequest(v) + }) +} + +// AddFallbackGroupIDOnInvalidRequest adds v to the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupIDOnInvalidRequest(v) + }) +} + +// UpdateFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupIDOnInvalidRequest() + }) +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (u *GroupUpsertBulk) ClearFallbackGroupIDOnInvalidRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupIDOnInvalidRequest() + }) +} + // SetModelRouting sets the "model_routing" field. func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { @@ -2632,6 +2824,34 @@ func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk { }) } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (u *GroupUpsertBulk) SetMcpXMLInject(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetMcpXMLInject(v) + }) +} + +// UpdateMcpXMLInject sets the "mcp_xml_inject" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateMcpXMLInject() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateMcpXMLInject() + }) +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (u *GroupUpsertBulk) SetSupportedModelScopes(v []string) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSupportedModelScopes(v) + }) +} + +// UpdateSupportedModelScopes sets the "supported_model_scopes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSupportedModelScopes() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 528a7fe9..f2142ce4 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" "entgo.io/ent/schema/field" "github.com/Wei-Shaw/sub2api/ent/account" "github.com/Wei-Shaw/sub2api/ent/apikey" @@ -503,6 +504,33 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdate { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdate { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdate) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdate { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate { _u.mutation.SetModelRouting(v) @@ -529,6 +557,32 @@ func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate { return _u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_u *GroupUpdate) SetMcpXMLInject(v bool) *GroupUpdate { + _u.mutation.SetMcpXMLInject(v) + return _u +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableMcpXMLInject(v *bool) *GroupUpdate { + if v != nil { + _u.SetMcpXMLInject(*v) + } + return _u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_u *GroupUpdate) SetSupportedModelScopes(v []string) *GroupUpdate { + _u.mutation.SetSupportedModelScopes(v) + return _u +} + +// AppendSupportedModelScopes appends value to the "supported_model_scopes" field. +func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate { + _u.mutation.AppendSupportedModelScopes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -973,6 +1027,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -982,6 +1045,17 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.ModelRoutingEnabled(); ok { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + } + if value, ok := _u.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, group.FieldSupportedModelScopes, value) + }) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1765,6 +1839,33 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { return _u } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) SetFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.ResetFallbackGroupIDOnInvalidRequest() + _u.mutation.SetFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// SetNillableFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableFallbackGroupIDOnInvalidRequest(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetFallbackGroupIDOnInvalidRequest(*v) + } + return _u +} + +// AddFallbackGroupIDOnInvalidRequest adds value to the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) AddFallbackGroupIDOnInvalidRequest(v int64) *GroupUpdateOne { + _u.mutation.AddFallbackGroupIDOnInvalidRequest(v) + return _u +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (_u *GroupUpdateOne) ClearFallbackGroupIDOnInvalidRequest() *GroupUpdateOne { + _u.mutation.ClearFallbackGroupIDOnInvalidRequest() + return _u +} + // SetModelRouting sets the "model_routing" field. func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne { _u.mutation.SetModelRouting(v) @@ -1791,6 +1892,32 @@ func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOn return _u } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (_u *GroupUpdateOne) SetMcpXMLInject(v bool) *GroupUpdateOne { + _u.mutation.SetMcpXMLInject(v) + return _u +} + +// SetNillableMcpXMLInject sets the "mcp_xml_inject" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableMcpXMLInject(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetMcpXMLInject(*v) + } + return _u +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (_u *GroupUpdateOne) SetSupportedModelScopes(v []string) *GroupUpdateOne { + _u.mutation.SetSupportedModelScopes(v) + return _u +} + +// AppendSupportedModelScopes appends value to the "supported_model_scopes" field. +func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne { + _u.mutation.AppendSupportedModelScopes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2265,6 +2392,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.FallbackGroupIDCleared() { _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) } + if value, ok := _u.mutation.FallbackGroupIDOnInvalidRequest(); ok { + _spec.SetField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupIDOnInvalidRequest(); ok { + _spec.AddField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDOnInvalidRequestCleared() { + _spec.ClearField(group.FieldFallbackGroupIDOnInvalidRequest, field.TypeInt64) + } if value, ok := _u.mutation.ModelRouting(); ok { _spec.SetField(group.FieldModelRouting, field.TypeJSON, value) } @@ -2274,6 +2410,17 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.ModelRoutingEnabled(); ok { _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.McpXMLInject(); ok { + _spec.SetField(group.FieldMcpXMLInject, field.TypeBool, value) + } + if value, ok := _u.mutation.SupportedModelScopes(); ok { + _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedSupportedModelScopes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, group.FieldSupportedModelScopes, value) + }) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 8df0cdb3..1536d40e 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -20,6 +20,9 @@ var ( {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, + {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "quota_used", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "expires_at", Type: field.TypeTime, Nullable: true}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "user_id", Type: field.TypeInt64}, } @@ -31,13 +34,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[9]}, + Columns: []*schema.Column{APIKeysColumns[12]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[13]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, @@ -46,12 +49,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[13]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[9]}, + Columns: []*schema.Column{APIKeysColumns[12]}, }, { Name: "apikey_status", @@ -63,6 +66,16 @@ var ( Unique: false, Columns: []*schema.Column{APIKeysColumns[3]}, }, + { + Name: "apikey_quota_quota_used", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]}, + }, + { + Name: "apikey_expires_at", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[11]}, + }, }, } // AccountsColumns holds the columns for the "accounts" table. @@ -322,8 +335,11 @@ var ( {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, + {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, + {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, + {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index f12ccb4f..c30e5559 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -79,6 +79,11 @@ type APIKeyMutation struct { appendip_whitelist []string ip_blacklist *[]string appendip_blacklist []string + quota *float64 + addquota *float64 + quota_used *float64 + addquota_used *float64 + expires_at *time.Time clearedFields map[string]struct{} user *int64 cleareduser bool @@ -634,6 +639,167 @@ func (m *APIKeyMutation) ResetIPBlacklist() { delete(m.clearedFields, apikey.FieldIPBlacklist) } +// SetQuota sets the "quota" field. +func (m *APIKeyMutation) SetQuota(f float64) { + m.quota = &f + m.addquota = nil +} + +// Quota returns the value of the "quota" field in the mutation. +func (m *APIKeyMutation) Quota() (r float64, exists bool) { + v := m.quota + if v == nil { + return + } + return *v, true +} + +// OldQuota returns the old "quota" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldQuota(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQuota is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQuota requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQuota: %w", err) + } + return oldValue.Quota, nil +} + +// AddQuota adds f to the "quota" field. +func (m *APIKeyMutation) AddQuota(f float64) { + if m.addquota != nil { + *m.addquota += f + } else { + m.addquota = &f + } +} + +// AddedQuota returns the value that was added to the "quota" field in this mutation. +func (m *APIKeyMutation) AddedQuota() (r float64, exists bool) { + v := m.addquota + if v == nil { + return + } + return *v, true +} + +// ResetQuota resets all changes to the "quota" field. +func (m *APIKeyMutation) ResetQuota() { + m.quota = nil + m.addquota = nil +} + +// SetQuotaUsed sets the "quota_used" field. +func (m *APIKeyMutation) SetQuotaUsed(f float64) { + m.quota_used = &f + m.addquota_used = nil +} + +// QuotaUsed returns the value of the "quota_used" field in the mutation. +func (m *APIKeyMutation) QuotaUsed() (r float64, exists bool) { + v := m.quota_used + if v == nil { + return + } + return *v, true +} + +// OldQuotaUsed returns the old "quota_used" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldQuotaUsed(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldQuotaUsed is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldQuotaUsed requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldQuotaUsed: %w", err) + } + return oldValue.QuotaUsed, nil +} + +// AddQuotaUsed adds f to the "quota_used" field. +func (m *APIKeyMutation) AddQuotaUsed(f float64) { + if m.addquota_used != nil { + *m.addquota_used += f + } else { + m.addquota_used = &f + } +} + +// AddedQuotaUsed returns the value that was added to the "quota_used" field in this mutation. +func (m *APIKeyMutation) AddedQuotaUsed() (r float64, exists bool) { + v := m.addquota_used + if v == nil { + return + } + return *v, true +} + +// ResetQuotaUsed resets all changes to the "quota_used" field. +func (m *APIKeyMutation) ResetQuotaUsed() { + m.quota_used = nil + m.addquota_used = nil +} + +// SetExpiresAt sets the "expires_at" field. +func (m *APIKeyMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *APIKeyMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (m *APIKeyMutation) ClearExpiresAt() { + m.expires_at = nil + m.clearedFields[apikey.FieldExpiresAt] = struct{}{} +} + +// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation. +func (m *APIKeyMutation) ExpiresAtCleared() bool { + _, ok := m.clearedFields[apikey.FieldExpiresAt] + return ok +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *APIKeyMutation) ResetExpiresAt() { + m.expires_at = nil + delete(m.clearedFields, apikey.FieldExpiresAt) +} + // ClearUser clears the "user" edge to the User entity. func (m *APIKeyMutation) ClearUser() { m.cleareduser = true @@ -776,7 +942,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 13) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -807,6 +973,15 @@ func (m *APIKeyMutation) Fields() []string { if m.ip_blacklist != nil { fields = append(fields, apikey.FieldIPBlacklist) } + if m.quota != nil { + fields = append(fields, apikey.FieldQuota) + } + if m.quota_used != nil { + fields = append(fields, apikey.FieldQuotaUsed) + } + if m.expires_at != nil { + fields = append(fields, apikey.FieldExpiresAt) + } return fields } @@ -835,6 +1010,12 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.IPWhitelist() case apikey.FieldIPBlacklist: return m.IPBlacklist() + case apikey.FieldQuota: + return m.Quota() + case apikey.FieldQuotaUsed: + return m.QuotaUsed() + case apikey.FieldExpiresAt: + return m.ExpiresAt() } return nil, false } @@ -864,6 +1045,12 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldIPWhitelist(ctx) case apikey.FieldIPBlacklist: return m.OldIPBlacklist(ctx) + case apikey.FieldQuota: + return m.OldQuota(ctx) + case apikey.FieldQuotaUsed: + return m.OldQuotaUsed(ctx) + case apikey.FieldExpiresAt: + return m.OldExpiresAt(ctx) } return nil, fmt.Errorf("unknown APIKey field %s", name) } @@ -943,6 +1130,27 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } m.SetIPBlacklist(v) return nil + case apikey.FieldQuota: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQuota(v) + return nil + case apikey.FieldQuotaUsed: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetQuotaUsed(v) + return nil + case apikey.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -951,6 +1159,12 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { // this mutation. func (m *APIKeyMutation) AddedFields() []string { var fields []string + if m.addquota != nil { + fields = append(fields, apikey.FieldQuota) + } + if m.addquota_used != nil { + fields = append(fields, apikey.FieldQuotaUsed) + } return fields } @@ -959,6 +1173,10 @@ func (m *APIKeyMutation) AddedFields() []string { // was not set, or was not defined in the schema. func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { switch name { + case apikey.FieldQuota: + return m.AddedQuota() + case apikey.FieldQuotaUsed: + return m.AddedQuotaUsed() } return nil, false } @@ -968,6 +1186,20 @@ func (m *APIKeyMutation) AddedField(name string) (ent.Value, bool) { // type. func (m *APIKeyMutation) AddField(name string, value ent.Value) error { switch name { + case apikey.FieldQuota: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddQuota(v) + return nil + case apikey.FieldQuotaUsed: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddQuotaUsed(v) + return nil } return fmt.Errorf("unknown APIKey numeric field %s", name) } @@ -988,6 +1220,9 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldIPBlacklist) { fields = append(fields, apikey.FieldIPBlacklist) } + if m.FieldCleared(apikey.FieldExpiresAt) { + fields = append(fields, apikey.FieldExpiresAt) + } return fields } @@ -1014,6 +1249,9 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldIPBlacklist: m.ClearIPBlacklist() return nil + case apikey.FieldExpiresAt: + m.ClearExpiresAt() + return nil } return fmt.Errorf("unknown APIKey nullable field %s", name) } @@ -1052,6 +1290,15 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldIPBlacklist: m.ResetIPBlacklist() return nil + case apikey.FieldQuota: + m.ResetQuota() + return nil + case apikey.FieldQuotaUsed: + m.ResetQuotaUsed() + return nil + case apikey.FieldExpiresAt: + m.ResetExpiresAt() + return nil } return fmt.Errorf("unknown APIKey field %s", name) } @@ -5506,69 +5753,74 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - description *string - rate_multiplier *float64 - addrate_multiplier *float64 - is_exclusive *bool - status *string - platform *string - subscription_type *string - daily_limit_usd *float64 - adddaily_limit_usd *float64 - weekly_limit_usd *float64 - addweekly_limit_usd *float64 - monthly_limit_usd *float64 - addmonthly_limit_usd *float64 - default_validity_days *int - adddefault_validity_days *int - image_price_1k *float64 - addimage_price_1k *float64 - image_price_2k *float64 - addimage_price_2k *float64 - image_price_4k *float64 - addimage_price_4k *float64 - sora_image_price_360 *float64 - addsora_image_price_360 *float64 - sora_image_price_540 *float64 - addsora_image_price_540 *float64 - sora_video_price_per_request *float64 - addsora_video_price_per_request *float64 - sora_video_price_per_request_hd *float64 - addsora_video_price_per_request_hd *float64 - claude_code_only *bool - fallback_group_id *int64 - addfallback_group_id *int64 - model_routing *map[string][]int64 - model_routing_enabled *bool - clearedFields map[string]struct{} - api_keys map[int64]struct{} - removedapi_keys map[int64]struct{} - clearedapi_keys bool - redeem_codes map[int64]struct{} - removedredeem_codes map[int64]struct{} - clearedredeem_codes bool - subscriptions map[int64]struct{} - removedsubscriptions map[int64]struct{} - clearedsubscriptions bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - accounts map[int64]struct{} - removedaccounts map[int64]struct{} - clearedaccounts bool - allowed_users map[int64]struct{} - removedallowed_users map[int64]struct{} - clearedallowed_users bool - done bool - oldValue func(context.Context) (*Group, error) - predicates []predicate.Group + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + description *string + rate_multiplier *float64 + addrate_multiplier *float64 + is_exclusive *bool + status *string + platform *string + subscription_type *string + daily_limit_usd *float64 + adddaily_limit_usd *float64 + weekly_limit_usd *float64 + addweekly_limit_usd *float64 + monthly_limit_usd *float64 + addmonthly_limit_usd *float64 + default_validity_days *int + adddefault_validity_days *int + image_price_1k *float64 + addimage_price_1k *float64 + image_price_2k *float64 + addimage_price_2k *float64 + image_price_4k *float64 + addimage_price_4k *float64 + sora_image_price_360 *float64 + addsora_image_price_360 *float64 + sora_image_price_540 *float64 + addsora_image_price_540 *float64 + sora_video_price_per_request *float64 + addsora_video_price_per_request *float64 + sora_video_price_per_request_hd *float64 + addsora_video_price_per_request_hd *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 + fallback_group_id_on_invalid_request *int64 + addfallback_group_id_on_invalid_request *int64 + model_routing *map[string][]int64 + model_routing_enabled *bool + mcp_xml_inject *bool + supported_model_scopes *[]string + appendsupported_model_scopes []string + clearedFields map[string]struct{} + api_keys map[int64]struct{} + removedapi_keys map[int64]struct{} + clearedapi_keys bool + redeem_codes map[int64]struct{} + removedredeem_codes map[int64]struct{} + clearedredeem_codes bool + subscriptions map[int64]struct{} + removedsubscriptions map[int64]struct{} + clearedsubscriptions bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + accounts map[int64]struct{} + removedaccounts map[int64]struct{} + clearedaccounts bool + allowed_users map[int64]struct{} + removedallowed_users map[int64]struct{} + clearedallowed_users bool + done bool + oldValue func(context.Context) (*Group, error) + predicates []predicate.Group } var _ ent.Mutation = (*GroupMutation)(nil) @@ -6937,6 +7189,76 @@ func (m *GroupMutation) ResetFallbackGroupID() { delete(m.clearedFields, group.FieldFallbackGroupID) } +// SetFallbackGroupIDOnInvalidRequest sets the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) SetFallbackGroupIDOnInvalidRequest(i int64) { + m.fallback_group_id_on_invalid_request = &i + m.addfallback_group_id_on_invalid_request = nil +} + +// FallbackGroupIDOnInvalidRequest returns the value of the "fallback_group_id_on_invalid_request" field in the mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.fallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// OldFallbackGroupIDOnInvalidRequest returns the old "fallback_group_id_on_invalid_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupIDOnInvalidRequest(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupIDOnInvalidRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupIDOnInvalidRequest: %w", err) + } + return oldValue.FallbackGroupIDOnInvalidRequest, nil +} + +// AddFallbackGroupIDOnInvalidRequest adds i to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) AddFallbackGroupIDOnInvalidRequest(i int64) { + if m.addfallback_group_id_on_invalid_request != nil { + *m.addfallback_group_id_on_invalid_request += i + } else { + m.addfallback_group_id_on_invalid_request = &i + } +} + +// AddedFallbackGroupIDOnInvalidRequest returns the value that was added to the "fallback_group_id_on_invalid_request" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupIDOnInvalidRequest() (r int64, exists bool) { + v := m.addfallback_group_id_on_invalid_request + if v == nil { + return + } + return *v, true +} + +// ClearFallbackGroupIDOnInvalidRequest clears the value of the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ClearFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] = struct{}{} +} + +// FallbackGroupIDOnInvalidRequestCleared returns if the "fallback_group_id_on_invalid_request" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDOnInvalidRequestCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupIDOnInvalidRequest] + return ok +} + +// ResetFallbackGroupIDOnInvalidRequest resets all changes to the "fallback_group_id_on_invalid_request" field. +func (m *GroupMutation) ResetFallbackGroupIDOnInvalidRequest() { + m.fallback_group_id_on_invalid_request = nil + m.addfallback_group_id_on_invalid_request = nil + delete(m.clearedFields, group.FieldFallbackGroupIDOnInvalidRequest) +} + // SetModelRouting sets the "model_routing" field. func (m *GroupMutation) SetModelRouting(value map[string][]int64) { m.model_routing = &value @@ -7022,6 +7344,93 @@ func (m *GroupMutation) ResetModelRoutingEnabled() { m.model_routing_enabled = nil } +// SetMcpXMLInject sets the "mcp_xml_inject" field. +func (m *GroupMutation) SetMcpXMLInject(b bool) { + m.mcp_xml_inject = &b +} + +// McpXMLInject returns the value of the "mcp_xml_inject" field in the mutation. +func (m *GroupMutation) McpXMLInject() (r bool, exists bool) { + v := m.mcp_xml_inject + if v == nil { + return + } + return *v, true +} + +// OldMcpXMLInject returns the old "mcp_xml_inject" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldMcpXMLInject(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMcpXMLInject is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMcpXMLInject requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMcpXMLInject: %w", err) + } + return oldValue.McpXMLInject, nil +} + +// ResetMcpXMLInject resets all changes to the "mcp_xml_inject" field. +func (m *GroupMutation) ResetMcpXMLInject() { + m.mcp_xml_inject = nil +} + +// SetSupportedModelScopes sets the "supported_model_scopes" field. +func (m *GroupMutation) SetSupportedModelScopes(s []string) { + m.supported_model_scopes = &s + m.appendsupported_model_scopes = nil +} + +// SupportedModelScopes returns the value of the "supported_model_scopes" field in the mutation. +func (m *GroupMutation) SupportedModelScopes() (r []string, exists bool) { + v := m.supported_model_scopes + if v == nil { + return + } + return *v, true +} + +// OldSupportedModelScopes returns the old "supported_model_scopes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSupportedModelScopes(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSupportedModelScopes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSupportedModelScopes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSupportedModelScopes: %w", err) + } + return oldValue.SupportedModelScopes, nil +} + +// AppendSupportedModelScopes adds s to the "supported_model_scopes" field. +func (m *GroupMutation) AppendSupportedModelScopes(s []string) { + m.appendsupported_model_scopes = append(m.appendsupported_model_scopes, s...) +} + +// AppendedSupportedModelScopes returns the list of values that were appended to the "supported_model_scopes" field in this mutation. +func (m *GroupMutation) AppendedSupportedModelScopes() ([]string, bool) { + if len(m.appendsupported_model_scopes) == 0 { + return nil, false + } + return m.appendsupported_model_scopes, true +} + +// ResetSupportedModelScopes resets all changes to the "supported_model_scopes" field. +func (m *GroupMutation) ResetSupportedModelScopes() { + m.supported_model_scopes = nil + m.appendsupported_model_scopes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -7450,12 +7859,21 @@ func (m *GroupMutation) Fields() []string { if m.fallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.fallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.model_routing != nil { fields = append(fields, group.FieldModelRouting) } if m.model_routing_enabled != nil { fields = append(fields, group.FieldModelRoutingEnabled) } + if m.mcp_xml_inject != nil { + fields = append(fields, group.FieldMcpXMLInject) + } + if m.supported_model_scopes != nil { + fields = append(fields, group.FieldSupportedModelScopes) + } return fields } @@ -7510,10 +7928,16 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: return m.FallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.FallbackGroupIDOnInvalidRequest() case group.FieldModelRouting: return m.ModelRouting() case group.FieldModelRoutingEnabled: return m.ModelRoutingEnabled() + case group.FieldMcpXMLInject: + return m.McpXMLInject() + case group.FieldSupportedModelScopes: + return m.SupportedModelScopes() } return nil, false } @@ -7569,10 +7993,16 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: return m.OldFallbackGroupID(ctx) + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.OldFallbackGroupIDOnInvalidRequest(ctx) case group.FieldModelRouting: return m.OldModelRouting(ctx) case group.FieldModelRoutingEnabled: return m.OldModelRoutingEnabled(ctx) + case group.FieldMcpXMLInject: + return m.OldMcpXMLInject(ctx) + case group.FieldSupportedModelScopes: + return m.OldSupportedModelScopes(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -7743,6 +8173,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupIDOnInvalidRequest(v) + return nil case group.FieldModelRouting: v, ok := value.(map[string][]int64) if !ok { @@ -7757,6 +8194,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetModelRoutingEnabled(v) return nil + case group.FieldMcpXMLInject: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMcpXMLInject(v) + return nil + case group.FieldSupportedModelScopes: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSupportedModelScopes(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -7804,6 +8255,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } + if m.addfallback_group_id_on_invalid_request != nil { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } return fields } @@ -7838,6 +8292,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedSoraVideoPricePerRequestHd() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() + case group.FieldFallbackGroupIDOnInvalidRequest: + return m.AddedFallbackGroupIDOnInvalidRequest() } return nil, false } @@ -7938,6 +8394,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddFallbackGroupID(v) return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupIDOnInvalidRequest(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -7985,6 +8448,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } + if m.FieldCleared(group.FieldFallbackGroupIDOnInvalidRequest) { + fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) + } if m.FieldCleared(group.FieldModelRouting) { fields = append(fields, group.FieldModelRouting) } @@ -8041,6 +8507,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ClearFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ClearModelRouting() return nil @@ -8121,12 +8590,21 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldFallbackGroupID: m.ResetFallbackGroupID() return nil + case group.FieldFallbackGroupIDOnInvalidRequest: + m.ResetFallbackGroupIDOnInvalidRequest() + return nil case group.FieldModelRouting: m.ResetModelRouting() return nil case group.FieldModelRoutingEnabled: m.ResetModelRoutingEnabled() return nil + case group.FieldMcpXMLInject: + m.ResetMcpXMLInject() + return nil + case group.FieldSupportedModelScopes: + m.ResetSupportedModelScopes() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index aeced47a..c0f0f123 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -91,6 +91,14 @@ func init() { apikey.DefaultStatus = apikeyDescStatus.Default.(string) // apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save. apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error) + // apikeyDescQuota is the schema descriptor for quota field. + apikeyDescQuota := apikeyFields[7].Descriptor() + // apikey.DefaultQuota holds the default value on creation for the quota field. + apikey.DefaultQuota = apikeyDescQuota.Default.(float64) + // apikeyDescQuotaUsed is the schema descriptor for quota_used field. + apikeyDescQuotaUsed := apikeyFields[8].Descriptor() + // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. + apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) accountMixin := schema.Account{}.Mixin() accountMixinHooks1 := accountMixin[1].Hooks() account.Hooks[0] = accountMixinHooks1[0] @@ -334,9 +342,17 @@ func init() { // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[21].Descriptor() + groupDescModelRoutingEnabled := groupFields[22].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) + // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. + groupDescMcpXMLInject := groupFields[23].Descriptor() + // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. + group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) + // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. + groupDescSupportedModelScopes := groupFields[24].Descriptor() + // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. + group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 1c2d4bd4..26d52cb0 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -5,6 +5,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/domain" "entgo.io/ent" + "entgo.io/ent/dialect" "entgo.io/ent/dialect/entsql" "entgo.io/ent/schema" "entgo.io/ent/schema/edge" @@ -52,6 +53,23 @@ func (APIKey) Fields() []ent.Field { field.JSON("ip_blacklist", []string{}). Optional(). Comment("Blocked IPs/CIDRs"), + + // ========== Quota fields ========== + // Quota limit in USD (0 = unlimited) + field.Float("quota"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Quota limit in USD for this API key (0 = unlimited)"), + // Used quota amount + field.Float("quota_used"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0). + Comment("Used quota amount in USD"), + // Expiration time (nil = never expires) + field.Time("expires_at"). + Optional(). + Nillable(). + Comment("Expiration time for this API key (null = never expires)"), } } @@ -77,5 +95,8 @@ func (APIKey) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("deleted_at"), + // Index for quota queries + index.Fields("quota", "quota_used"), + index.Fields("expires_at"), } } diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 65b57754..cb1e5eec 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -113,6 +113,10 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). Comment("非 Claude Code 请求降级使用的分组 ID"), + field.Int64("fallback_group_id_on_invalid_request"). + Optional(). + Nillable(). + Comment("无效请求兜底使用的分组 ID"), // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). @@ -124,6 +128,17 @@ func (Group) Fields() []ent.Field { field.Bool("model_routing_enabled"). Default(false). Comment("是否启用模型路由配置"), + + // MCP XML 协议注入开关 (added by migration 042) + field.Bool("mcp_xml_inject"). + Default(true). + Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"), + + // 支持的模型系列 (added by migration 046) + field.JSON("supported_model_scopes", []string{}). + Default([]string{"claude", "gemini_text", "gemini_image"}). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}). + Comment("支持的模型系列:claude, gemini_text, gemini_image"), } } diff --git a/backend/go.mod b/backend/go.mod index 4c3e6246..9a36a0f1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -4,6 +4,8 @@ go 1.25.6 require ( entgo.io/ent v0.14.5 + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 @@ -11,7 +13,10 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 github.com/lib/pq v1.10.9 + github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.17.2 + github.com/refraction-networking/utls v1.8.1 + github.com/robfig/cron/v3 v3.0.1 github.com/shirou/gopsutil/v4 v4.25.6 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.11.1 @@ -25,13 +30,13 @@ require ( golang.org/x/sync v0.19.0 golang.org/x/term v0.38.0 gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.44.3 ) require ( ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect - github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect @@ -48,7 +53,6 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect @@ -107,13 +111,10 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect - github.com/pquerna/otp v1.5.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.57.1 // indirect - github.com/refraction-networking/utls v1.8.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.2.0 // indirect - github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect @@ -149,12 +150,10 @@ require ( golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect golang.org/x/tools v0.39.0 // indirect - golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect - modernc.org/sqlite v1.44.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 0addb5bb..371623ad 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -55,6 +55,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= +github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= +github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -113,6 +115,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -123,6 +127,9 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo= @@ -345,8 +352,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= @@ -374,9 +379,8 @@ golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= -golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY= -golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -399,12 +403,32 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas= -modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY= +modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 1aee7777..9de0e948 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -30,6 +30,7 @@ const ( AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference) AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope) AccountTypeAPIKey = "apikey" // API Key类型账号 + AccountTypeUpstream = "upstream" // 上游透传类型账号(通过 Base URL + API Key 连接上游) ) // Redeem type constants diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index bbf5d026..6d42f726 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -84,7 +84,7 @@ type CreateAccountRequest struct { Name string `json:"name" binding:"required"` Notes *string `json:"notes"` Platform string `json:"platform" binding:"required"` - Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"` + Type string `json:"type" binding:"required,oneof=oauth setup-token apikey upstream"` Credentials map[string]any `json:"credentials" binding:"required"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` @@ -102,7 +102,7 @@ type CreateAccountRequest struct { type UpdateAccountRequest struct { Name string `json:"name"` Notes *string `json:"notes"` - Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"` + Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey upstream"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` ProxyID *int64 `json:"proxy_id"` diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index f7f6c893..20a20767 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -35,18 +35,22 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -64,18 +68,22 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` - SoraImagePrice360 *float64 `json:"sora_image_price_360"` - SoraImagePrice540 *float64 `json:"sora_image_price_540"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` - ClaudeCodeOnly *bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled *bool `json:"model_routing_enabled"` + MCPXMLInject *bool `json:"mcp_xml_inject"` + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string `json:"supported_model_scopes"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -167,27 +175,30 @@ func (h *GroupHandler) Create(c *gin.Context) { } group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, - CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, + MCPXMLInject: req.MCPXMLInject, + SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) @@ -213,28 +224,31 @@ func (h *GroupHandler) Update(c *gin.Context) { } group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{ - Name: req.Name, - Description: req.Description, - Platform: req.Platform, - RateMultiplier: req.RateMultiplier, - IsExclusive: req.IsExclusive, - Status: req.Status, - SubscriptionType: req.SubscriptionType, - DailyLimitUSD: req.DailyLimitUSD, - WeeklyLimitUSD: req.WeeklyLimitUSD, - MonthlyLimitUSD: req.MonthlyLimitUSD, - ImagePrice1K: req.ImagePrice1K, - ImagePrice2K: req.ImagePrice2K, - ImagePrice4K: req.ImagePrice4K, - SoraImagePrice360: req.SoraImagePrice360, - SoraImagePrice540: req.SoraImagePrice540, - SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: req.ClaudeCodeOnly, - FallbackGroupID: req.FallbackGroupID, - ModelRouting: req.ModelRouting, - ModelRoutingEnabled: req.ModelRoutingEnabled, - CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, + Name: req.Name, + Description: req.Description, + Platform: req.Platform, + RateMultiplier: req.RateMultiplier, + IsExclusive: req.IsExclusive, + Status: req.Status, + SubscriptionType: req.SubscriptionType, + DailyLimitUSD: req.DailyLimitUSD, + WeeklyLimitUSD: req.WeeklyLimitUSD, + MonthlyLimitUSD: req.MonthlyLimitUSD, + ImagePrice1K: req.ImagePrice1K, + ImagePrice2K: req.ImagePrice2K, + ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, + ModelRouting: req.ModelRouting, + ModelRoutingEnabled: req.ModelRoutingEnabled, + MCPXMLInject: req.MCPXMLInject, + SupportedModelScopes: req.SupportedModelScopes, + CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 52dc6911..9717194b 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -3,6 +3,7 @@ package handler import ( "strconv" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" @@ -27,11 +28,13 @@ func NewAPIKeyHandler(apiKeyService *service.APIKeyService) *APIKeyHandler { // CreateAPIKeyRequest represents the create API key request payload type CreateAPIKeyRequest struct { - Name string `json:"name" binding:"required"` - GroupID *int64 `json:"group_id"` // nullable - CustomKey *string `json:"custom_key"` // 可选的自定义key - IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 - IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Name string `json:"name" binding:"required"` + GroupID *int64 `json:"group_id"` // nullable + CustomKey *string `json:"custom_key"` // 可选的自定义key + IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 + IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Quota *float64 `json:"quota"` // 配额限制 (USD) + ExpiresInDays *int `json:"expires_in_days"` // 过期天数 } // UpdateAPIKeyRequest represents the update API key request payload @@ -41,6 +44,9 @@ type UpdateAPIKeyRequest struct { Status string `json:"status" binding:"omitempty,oneof=active inactive"` IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + Quota *float64 `json:"quota"` // 配额限制 (USD), 0=无限制 + ExpiresAt *string `json:"expires_at"` // 过期时间 (ISO 8601) + ResetQuota *bool `json:"reset_quota"` // 重置已用配额 } // List handles listing user's API keys with pagination @@ -114,11 +120,15 @@ func (h *APIKeyHandler) Create(c *gin.Context) { } svcReq := service.CreateAPIKeyRequest{ - Name: req.Name, - GroupID: req.GroupID, - CustomKey: req.CustomKey, - IPWhitelist: req.IPWhitelist, - IPBlacklist: req.IPBlacklist, + Name: req.Name, + GroupID: req.GroupID, + CustomKey: req.CustomKey, + IPWhitelist: req.IPWhitelist, + IPBlacklist: req.IPBlacklist, + ExpiresInDays: req.ExpiresInDays, + } + if req.Quota != nil { + svcReq.Quota = *req.Quota } key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) if err != nil { @@ -153,6 +163,8 @@ func (h *APIKeyHandler) Update(c *gin.Context) { svcReq := service.UpdateAPIKeyRequest{ IPWhitelist: req.IPWhitelist, IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + ResetQuota: req.ResetQuota, } if req.Name != "" { svcReq.Name = &req.Name @@ -161,6 +173,21 @@ func (h *APIKeyHandler) Update(c *gin.Context) { if req.Status != "" { svcReq.Status = &req.Status } + // Parse expires_at if provided + if req.ExpiresAt != nil { + if *req.ExpiresAt == "" { + // Empty string means clear expiration + svcReq.ExpiresAt = nil + svcReq.ClearExpiration = true + } else { + t, err := time.Parse(time.RFC3339, *req.ExpiresAt) + if err != nil { + response.BadRequest(c, "Invalid expires_at format: "+err.Error()) + return + } + svcReq.ExpiresAt = &t + } + } key, err := h.apiKeyService.Update(c.Request.Context(), keyID, subject.UserID, svcReq) if err != nil { diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 04d1385d..2d183485 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -76,6 +76,9 @@ func APIKeyFromService(k *service.APIKey) *APIKey { Status: k.Status, IPWhitelist: k.IPWhitelist, IPBlacklist: k.IPBlacklist, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, CreatedAt: k.CreatedAt, UpdatedAt: k.UpdatedAt, User: UserFromServiceShallow(k.User), @@ -105,10 +108,12 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - AccountCount: g.AccountCount, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -122,28 +127,29 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { func groupFromServiceBase(g *service.Group) Group { return Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index f2c7f5f1..602b4a44 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -32,16 +32,19 @@ type AdminUser struct { } type APIKey struct { - ID int64 `json:"id"` - UserID int64 `json:"user_id"` - Key string `json:"key"` - Name string `json:"name"` - GroupID *int64 `json:"group_id"` - Status string `json:"status"` - IPWhitelist []string `json:"ip_whitelist"` - IPBlacklist []string `json:"ip_blacklist"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + Key string `json:"key"` + Name string `json:"name"` + GroupID *int64 `json:"group_id"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist"` + IPBlacklist []string `json:"ip_blacklist"` + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` @@ -75,6 +78,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` + // 无效请求兜底分组 + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -89,8 +94,13 @@ type AdminGroup struct { ModelRouting map[string][]int64 `json:"model_routing"` ModelRoutingEnabled bool `json:"model_routing_enabled"` - AccountGroups []AccountGroup `json:"account_groups,omitempty"` - AccountCount int64 `json:"account_count,omitempty"` + // MCP XML 协议注入(仅 antigravity 平台使用) + MCPXMLInject bool `json:"mcp_xml_inject"` + + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` } type Account struct { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 673f6369..21795fb4 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -14,6 +14,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -31,6 +32,7 @@ type GatewayHandler struct { userService *service.UserService billingCacheService *service.BillingCacheService usageService *service.UsageService + apiKeyService *service.APIKeyService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int @@ -46,6 +48,7 @@ func NewGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, usageService *service.UsageService, + apiKeyService *service.APIKeyService, cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) @@ -67,6 +70,7 @@ func NewGatewayHandler( userService: userService, billingCacheService: billingCacheService, usageService: usageService, + apiKeyService: apiKeyService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, @@ -283,10 +287,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 转发请求 - 根据账号平台分流 var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, reqModel, "generateContent", reqStream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body) } else { - result, err = h.geminiCompatService.Forward(c.Request.Context(), c, account, body) + result, err = h.geminiCompatService.Forward(requestCtx, c, account, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -318,13 +326,14 @@ func (h *GatewayHandler) Messages(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } @@ -333,139 +342,193 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } } - maxAccountSwitches := h.maxAccountSwitches - switchCount := 0 - failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + currentAPIKey := apiKey + currentSubscription := subscription + var fallbackGroupID *int64 + if apiKey.Group != nil { + fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest + } + fallbackUsed := false for { - // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) - if err != nil { - if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) - return - } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) - return - } - account := selection.Account - setOpsSelectedAccount(c, account.ID) + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + retryWithFallback := false - // 检查请求拦截(预热请求、SUGGESTION MODE等) - if account.IsInterceptWarmupEnabled() { - interceptType := detectInterceptType(body) - if interceptType != InterceptTypeNone { - if selection.Acquired && selection.ReleaseFunc != nil { - selection.ReleaseFunc() - } - if reqStream { - sendMockInterceptStream(c, reqModel, interceptType) - } else { - sendMockInterceptResponse(c, reqModel, interceptType) - } - return - } - } - - // 3. 获取账号并发槽位 - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + for { + // 选择支持该模型的账号 + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, - ) - if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) - } - } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) - - // 转发请求 - 根据账号平台分流 - var result *service.ForwardResult - if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.Forward(c.Request.Context(), c, account, body) - } else { - result, err = h.gatewayService.Forward(c.Request.Context(), c, account, parsedReq) - } - if accountReleaseFunc != nil { - accountReleaseFunc() - } - if err != nil { - var failoverErr *service.UpstreamFailoverError - if errors.As(err, &failoverErr) { - failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode - if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) - continue + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + account := selection.Account + setOpsSelectedAccount(c, account.ID) + + // 检查请求拦截(预热请求、SUGGESTION MODE等) + if account.IsInterceptWarmupEnabled() { + interceptType := detectInterceptType(body) + if interceptType != InterceptTypeNone { + if selection.Acquired && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + if reqStream { + sendMockInterceptStream(c, reqModel, interceptType) + } else { + sendMockInterceptResponse(c, reqModel, interceptType) + } + return + } + } + + // 3. 获取账号并发槽位 + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } + } + // 账号槽位/等待计数需要在超时或断开时安全回收 + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + // 转发请求 - 根据账号平台分流 + var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } + if account.Platform == service.PlatformAntigravity { + result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body) + } else { + result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) + } + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var promptTooLongErr *service.PromptTooLongError + if errors.As(err, &promptTooLongErr) { + log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { + fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) + if err != nil { + log.Printf("Resolve fallback group failed: %v", err) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + if fallbackGroup.Platform != service.PlatformAnthropic || + fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || + fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 + ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") + c.Request = c.Request.WithContext(ctx) + currentAPIKey = fallbackAPIKey + currentSubscription = nil + fallbackUsed = true + retryWithFallback = true + break + } + _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) + return + } + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + lastFailoverStatus = failoverErr.StatusCode + if switchCount >= maxAccountSwitches { + h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + return + } + switchCount++ + log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + continue + } + // 错误响应已在Forward中处理,这里只记录日志 + log.Printf("Account %d: Forward request failed: %v", account.ID, err) + return + } + + // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 异步记录使用量(subscription已在函数开头获取) + go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: currentAPIKey, + User: currentAPIKey.User, + Account: usedAccount, + Subscription: currentSubscription, + UserAgent: ua, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + log.Printf("Record usage failed: %v", err) + } + }(result, account, userAgent, clientIP) + return + } + if !retryWithFallback { return } - - // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) - userAgent := c.GetHeader("User-Agent") - clientIP := ip.GetClientIP(c) - - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: clientIP, - }); err != nil { - log.Printf("Record usage failed: %v", err) - } - }(result, account, userAgent, clientIP) - return } } @@ -540,6 +603,17 @@ func (h *GatewayHandler) AntigravityModels(c *gin.Context) { }) } +func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey { + if apiKey == nil || group == nil { + return apiKey + } + cloned := *apiKey + groupID := group.ID + cloned.GroupID = &groupID + cloned.Group = group + return &cloned +} + // Usage handles getting account balance and usage statistics for CC Switch integration // GET /v1/usage func (h *GatewayHandler) Usage(c *gin.Context) { diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d1b19ede..787e3760 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -14,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" @@ -335,10 +336,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 5) forward (根据平台分流) var result *service.ForwardResult + requestCtx := c.Request.Context() + if switchCount > 0 { + requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) + } if account.Platform == service.PlatformAntigravity { - result, err = h.antigravityGatewayService.ForwardGemini(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body) } else { - result, err = h.geminiCompatService.ForwardNative(c.Request.Context(), c, account, modelName, action, stream, body) + result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) } if accountReleaseFunc != nil { accountReleaseFunc() @@ -381,6 +386,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { IPAddress: ip, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 4c9dd8b9..a84679ae 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -24,6 +24,7 @@ import ( type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int } @@ -33,6 +34,7 @@ func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, + apiKeyService *service.APIKeyService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -46,6 +48,7 @@ func NewOpenAIGatewayHandler( return &OpenAIGatewayHandler{ gatewayService: gatewayService, billingCacheService: billingCacheService, + apiKeyService: apiKeyService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, } @@ -299,13 +302,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: usedAccount, - Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: usedAccount, + Subscription: subscription, + UserAgent: ua, + IPAddress: ip, + APIKeyService: h.apiKeyService, }); err != nil { log.Printf("Record usage failed: %v", err) } diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index c7d657b9..d1712c98 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -40,17 +40,48 @@ const ( // URL 可用性 TTL(不可用 URL 的恢复时间) URLAvailabilityTTL = 5 * time.Minute + + // Antigravity API 端点 + antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com" + antigravityDailyBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" ) // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ - "https://cloudcode-pa.googleapis.com", // prod (优先) - "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用) + antigravityProdBaseURL, // prod (优先) + antigravityDailyBaseURL, // daily sandbox (备用) } // BaseURL 默认 URL(保持向后兼容) var BaseURL = BaseURLs[0] +// ForwardBaseURLs 返回 API 转发用的 URL 顺序(daily 优先) +func ForwardBaseURLs() []string { + if len(BaseURLs) == 0 { + return nil + } + urls := append([]string(nil), BaseURLs...) + dailyIndex := -1 + for i, url := range urls { + if url == antigravityDailyBaseURL { + dailyIndex = i + break + } + } + if dailyIndex <= 0 { + return urls + } + reordered := make([]string, 0, len(urls)) + reordered = append(reordered, urls[dailyIndex]) + for i, url := range urls { + if i == dailyIndex { + continue + } + reordered = append(reordered, url) + } + return reordered +} + // URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) type URLAvailability struct { mu sync.RWMutex @@ -100,22 +131,37 @@ func (u *URLAvailability) IsAvailable(url string) bool { // GetAvailableURLs 返回可用的 URL 列表 // 最近成功的 URL 优先,其他按默认顺序 func (u *URLAvailability) GetAvailableURLs() []string { + return u.GetAvailableURLsWithBase(BaseURLs) +} + +// GetAvailableURLsWithBase 返回可用的 URL 列表(使用自定义顺序) +// 最近成功的 URL 优先,其他按传入顺序 +func (u *URLAvailability) GetAvailableURLsWithBase(baseURLs []string) []string { u.mu.RLock() defer u.mu.RUnlock() now := time.Now() - result := make([]string, 0, len(BaseURLs)) + result := make([]string, 0, len(baseURLs)) // 如果有最近成功的 URL 且可用,放在最前面 if u.lastSuccess != "" { - expiry, exists := u.unavailable[u.lastSuccess] - if !exists || now.After(expiry) { - result = append(result, u.lastSuccess) + found := false + for _, url := range baseURLs { + if url == u.lastSuccess { + found = true + break + } + } + if found { + expiry, exists := u.unavailable[u.lastSuccess] + if !exists || now.After(expiry) { + result = append(result, u.lastSuccess) + } } } - // 添加其他可用的 URL(按默认顺序) - for _, url := range BaseURLs { + // 添加其他可用的 URL(按传入顺序) + for _, url := range baseURLs { // 跳过已添加的 lastSuccess if url == u.lastSuccess { continue diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 63f6ee7c..972771a8 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -44,11 +44,13 @@ type TransformOptions struct { // IdentityPatch 可选:自定义注入到 systemInstruction 开头的身份防护提示词; // 为空时使用默认模板(包含 [IDENTITY_PATCH] 及 SYSTEM_PROMPT_BEGIN 标记)。 IdentityPatch string + EnableMCPXML bool } func DefaultTransformOptions() TransformOptions { return TransformOptions{ EnableIdentityPatch: true, + EnableMCPXML: true, } } @@ -257,8 +259,8 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans // 添加用户的 system prompt parts = append(parts, userSystemParts...) - // 检测是否有 MCP 工具,如有则注入 XML 调用协议 - if hasMCPTools(tools) { + // 检测是否有 MCP 工具,如有且启用了 MCP XML 注入则注入 XML 调用协议 + if opts.EnableMCPXML && hasMCPTools(tools) { parts = append(parts, GeminiPart{Text: mcpXMLProtocol}) } @@ -312,7 +314,7 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT parts = append([]GeminiPart{{ Text: "Thinking...", Thought: true, - ThoughtSignature: dummyThoughtSignature, + ThoughtSignature: DummyThoughtSignature, }}, parts...) } } @@ -330,9 +332,10 @@ func buildContents(messages []ClaudeMessage, toolIDToName map[string]string, isT return contents, strippedThinking, nil } -// dummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证 +// DummyThoughtSignature 用于跳过 Gemini 3 thought_signature 验证 // 参考: https://ai.google.dev/gemini-api/docs/thought-signatures -const dummyThoughtSignature = "skip_thought_signature_validator" +// 导出供跨包使用(如 gemini_native_signature_cleaner 跨账号修复) +const DummyThoughtSignature = "skip_thought_signature_validator" // buildParts 构建消息的 parts // allowDummyThought: 只有 Gemini 模型支持 dummy thought signature @@ -370,7 +373,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu // signature 处理: // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature - if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { + if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) { part.ThoughtSignature = block.Signature } else if !allowDummyThought { // Claude 模型需要有效 signature;在缺失时降级为普通文本,并在上层禁用 thinking mode。 @@ -381,7 +384,7 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu continue } else { // Gemini 模型使用 dummy signature - part.ThoughtSignature = dummyThoughtSignature + part.ThoughtSignature = DummyThoughtSignature } parts = append(parts, part) @@ -411,10 +414,10 @@ func buildParts(content json.RawMessage, toolIDToName map[string]string, allowDu // tool_use 的 signature 处理: // - Claude 模型(allowDummyThought=false):必须是上游返回的真实 signature(dummy 视为缺失) // - Gemini 模型(allowDummyThought=true):优先透传真实 signature,缺失时使用 dummy signature - if block.Signature != "" && (allowDummyThought || block.Signature != dummyThoughtSignature) { + if block.Signature != "" && (allowDummyThought || block.Signature != DummyThoughtSignature) { part.ThoughtSignature = block.Signature } else if allowDummyThought { - part.ThoughtSignature = dummyThoughtSignature + part.ThoughtSignature = DummyThoughtSignature } parts = append(parts, part) @@ -492,9 +495,23 @@ func parseToolResultContent(content json.RawMessage, isError bool) string { } // buildGenerationConfig 构建 generationConfig +const ( + defaultMaxOutputTokens = 64000 + maxOutputTokensUpperBound = 65000 + maxOutputTokensClaude = 64000 +) + +func maxOutputTokensLimit(model string) int { + if strings.HasPrefix(model, "claude-") { + return maxOutputTokensClaude + } + return maxOutputTokensUpperBound +} + func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { + maxLimit := maxOutputTokensLimit(req.Model) config := &GeminiGenerationConfig{ - MaxOutputTokens: 64000, // 默认最大输出 + MaxOutputTokens: defaultMaxOutputTokens, // 默认最大输出 StopSequences: DefaultStopSequences, } @@ -518,6 +535,10 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { } } + if config.MaxOutputTokens > maxLimit { + config.MaxOutputTokens = maxLimit + } + // 其他参数 if req.Temperature != nil { config.Temperature = req.Temperature diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index 9d62a4a1..f938b47f 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -86,7 +86,7 @@ func TestBuildParts_ThinkingBlockWithoutSignature(t *testing.T) { if len(parts) != 3 { t.Fatalf("expected 3 parts, got %d", len(parts)) } - if !parts[1].Thought || parts[1].ThoughtSignature != dummyThoughtSignature { + if !parts[1].Thought || parts[1].ThoughtSignature != DummyThoughtSignature { t.Fatalf("expected dummy thought signature, got thought=%v signature=%q", parts[1].Thought, parts[1].ThoughtSignature) } @@ -126,8 +126,8 @@ func TestBuildParts_ToolUseSignatureHandling(t *testing.T) { if len(parts) != 1 || parts[0].FunctionCall == nil { t.Fatalf("expected 1 functionCall part, got %+v", parts) } - if parts[0].ThoughtSignature != dummyThoughtSignature { - t.Fatalf("expected dummy tool signature %q, got %q", dummyThoughtSignature, parts[0].ThoughtSignature) + if parts[0].ThoughtSignature != DummyThoughtSignature { + t.Fatalf("expected dummy tool signature %q, got %q", DummyThoughtSignature, parts[0].ThoughtSignature) } }) diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 27bb5ac5..fd7512f7 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -14,6 +14,9 @@ const ( // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" + // AccountSwitchCount 表示请求过程中发生的账号切换次数 + AccountSwitchCount Key = "ctx_account_switch_count" + // IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端 IsClaudeCodeClient Key = "ctx_is_claude_code_client" // Group 认证后的分组信息,由 API Key 认证中间件设置 diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 25fb88b8..78db326c 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetKey(key.Key). SetName(key.Name). SetStatus(key.Status). - SetNillableGroupID(key.GroupID) + SetNillableGroupID(key.GroupID). + SetQuota(key.Quota). + SetQuotaUsed(key.QuotaUsed). + SetNillableExpiresAt(key.ExpiresAt) if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se apikey.FieldStatus, apikey.FieldIPWhitelist, apikey.FieldIPBlacklist, + apikey.FieldQuota, + apikey.FieldQuotaUsed, + apikey.FieldExpiresAt, ). WithUser(func(q *dbent.UserQuery) { q.Select( @@ -140,8 +146,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, + group.FieldFallbackGroupIDOnInvalidRequest, group.FieldModelRoutingEnabled, group.FieldModelRouting, + group.FieldMcpXMLInject, + group.FieldSupportedModelScopes, ) }). Only(ctx) @@ -165,6 +174,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()). SetName(key.Name). SetStatus(key.Status). + SetQuota(key.Quota). + SetQuotaUsed(key.QuotaUsed). SetUpdatedAt(now) if key.GroupID != nil { builder.SetGroupID(*key.GroupID) @@ -172,6 +183,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro builder.ClearGroupID() } + // Expiration time + if key.ExpiresAt != nil { + builder.SetExpiresAt(*key.ExpiresAt) + } else { + builder.ClearExpiresAt() + } + // IP 限制字段 if len(key.IPWhitelist) > 0 { builder.SetIPWhitelist(key.IPWhitelist) @@ -361,6 +379,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } +// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + // Use raw SQL for atomic increment to avoid race conditions + // First get current value + m, err := r.activeQuery(). + Where(apikey.IDEQ(id)). + Select(apikey.FieldQuotaUsed). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return 0, service.ErrAPIKeyNotFound + } + return 0, err + } + + newValue := m.QuotaUsed + amount + + // Update with new value + affected, err := r.client.APIKey.Update(). + Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). + SetQuotaUsed(newValue). + Save(ctx) + if err != nil { + return 0, err + } + if affected == 0 { + return 0, service.ErrAPIKeyNotFound + } + + return newValue, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil @@ -376,6 +426,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, GroupID: m.GroupID, + Quota: m.Quota, + QuotaUsed: m.QuotaUsed, + ExpiresAt: m.ExpiresAt, } if m.Edges.User != nil { out.User = userEntityToService(m.Edges.User) @@ -413,32 +466,35 @@ func groupEntityToService(g *dbent.Group) *service.Group { return nil } return &service.Group{ - ID: g.ID, - Name: g.Name, - Description: derefString(g.Description), - Platform: g.Platform, - RateMultiplier: g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - Hydrated: true, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUsd, - WeeklyLimitUSD: g.WeeklyLimitUsd, - MonthlyLimitUSD: g.MonthlyLimitUsd, - ImagePrice1K: g.ImagePrice1k, - ImagePrice2K: g.ImagePrice2k, - ImagePrice4K: g.ImagePrice4k, - SoraImagePrice360: g.SoraImagePrice360, - SoraImagePrice540: g.SoraImagePrice540, - SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, - DefaultValidityDays: g.DefaultValidityDays, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - CreatedAt: g.CreatedAt, - UpdatedAt: g.UpdatedAt, + ID: g.ID, + Name: g.Name, + Description: derefString(g.Description), + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + Hydrated: true, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUsd, + WeeklyLimitUSD: g.WeeklyLimitUsd, + MonthlyLimitUSD: g.MonthlyLimitUsd, + ImagePrice1K: g.ImagePrice1k, + ImagePrice2K: g.ImagePrice2k, + ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, + DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.McpXMLInject, + SupportedModelScopes: g.SupportedModelScopes, + CreatedAt: g.CreatedAt, + UpdatedAt: g.UpdatedAt, } } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 14e5cb86..5fb486df 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -54,13 +54,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). - SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). + SetMcpXMLInject(groupIn.MCPXMLInject) // 设置模型路由配置 if groupIn.ModelRouting != nil { builder = builder.SetModelRouting(groupIn.ModelRouting) } + // 设置支持的模型系列(始终设置,空数组表示不限制) + builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes) + created, err := builder.Save(ctx) if err == nil { groupIn.ID = created.ID @@ -91,7 +96,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G if err != nil { return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil) } - return groupEntityToService(m), nil } @@ -116,7 +120,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). - SetModelRoutingEnabled(groupIn.ModelRoutingEnabled) + SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). + SetMcpXMLInject(groupIn.MCPXMLInject) // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -124,6 +129,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } else { builder = builder.ClearFallbackGroupID() } + // 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置 + if groupIn.FallbackGroupIDOnInvalidRequest != nil { + builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest) + } else { + builder = builder.ClearFallbackGroupIDOnInvalidRequest() + } // 处理 ModelRouting:nil 时清除,否则设置 if groupIn.ModelRouting != nil { @@ -132,6 +143,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er builder = builder.ClearModelRouting() } + // 处理 SupportedModelScopes(始终设置,空数组表示不限制) + builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes) + updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) diff --git a/backend/internal/repository/ops_repo_metrics.go b/backend/internal/repository/ops_repo_metrics.go index 713e0eb9..f1e57c38 100644 --- a/backend/internal/repository/ops_repo_metrics.go +++ b/backend/internal/repository/ops_repo_metrics.go @@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics ( upstream_529_count, token_consumed, + account_switch_count, qps, tps, @@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics ( $1,$2,$3,$4, $5,$6,$7,$8, $9,$10,$11, - $12,$13,$14, - $15,$16,$17,$18,$19,$20, - $21,$22,$23,$24,$25,$26, - $27,$28,$29,$30, - $31,$32, - $33,$34, - $35,$36,$37, - $38,$39 + $12,$13,$14,$15, + $16,$17,$18,$19,$20,$21, + $22,$23,$24,$25,$26,$27, + $28,$29,$30,$31, + $32,$33, + $34,$35, + $36,$37,$38, + $39,$40 )` _, err := r.db.ExecContext( @@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics ( input.Upstream529Count, input.TokenConsumed, + input.AccountSwitchCount, opsNullFloat64(input.QPS), opsNullFloat64(input.TPS), @@ -177,7 +179,8 @@ SELECT db_conn_waiting, goroutine_count, - concurrency_queue_depth + concurrency_queue_depth, + account_switch_count FROM ops_system_metrics WHERE window_minutes = $1 AND platform IS NULL @@ -199,6 +202,7 @@ LIMIT 1` var dbWaiting sql.NullInt64 var goroutines sql.NullInt64 var queueDepth sql.NullInt64 + var accountSwitchCount sql.NullInt64 if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan( &out.ID, @@ -217,6 +221,7 @@ LIMIT 1` &dbWaiting, &goroutines, &queueDepth, + &accountSwitchCount, ); err != nil { return nil, err } @@ -273,6 +278,10 @@ LIMIT 1` v := int(queueDepth.Int64) out.ConcurrencyQueueDepth = &v } + if accountSwitchCount.Valid { + v := accountSwitchCount.Int64 + out.AccountSwitchCount = &v + } return &out, nil } diff --git a/backend/internal/repository/ops_repo_trends.go b/backend/internal/repository/ops_repo_trends.go index 022d1187..14394ed8 100644 --- a/backend/internal/repository/ops_repo_trends.go +++ b/backend/internal/repository/ops_repo_trends.go @@ -56,18 +56,44 @@ error_buckets AS ( AND COALESCE(status_code, 0) >= 400 GROUP BY 1 ), +switch_buckets AS ( + SELECT ` + errorBucketExpr + ` AS bucket, + COALESCE(SUM(CASE + WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1 + ELSE 0 + END), 0) AS switch_count + FROM ops_error_logs + CROSS JOIN LATERAL jsonb_array_elements( + COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb) + ) AS ev + ` + errorWhere + ` + AND upstream_errors IS NOT NULL + GROUP BY 1 +), combined AS ( - SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.success_count, 0) AS success_count, - COALESCE(e.error_count, 0) AS error_count, - COALESCE(u.token_consumed, 0) AS token_consumed - FROM usage_buckets u - FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket + SELECT + bucket, + SUM(success_count) AS success_count, + SUM(error_count) AS error_count, + SUM(token_consumed) AS token_consumed, + SUM(switch_count) AS switch_count + FROM ( + SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count + FROM usage_buckets + UNION ALL + SELECT bucket, 0, error_count, 0, 0 + FROM error_buckets + UNION ALL + SELECT bucket, 0, 0, 0, switch_count + FROM switch_buckets + ) t + GROUP BY bucket ) SELECT bucket, (success_count + error_count) AS request_count, - token_consumed + token_consumed, + switch_count FROM combined ORDER BY bucket ASC` @@ -84,13 +110,18 @@ ORDER BY bucket ASC` var bucket time.Time var requests int64 var tokens sql.NullInt64 - if err := rows.Scan(&bucket, &requests, &tokens); err != nil { + var switches sql.NullInt64 + if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil { return nil, err } tokenConsumed := int64(0) if tokens.Valid { tokenConsumed = tokens.Int64 } + switchCount := int64(0) + if switches.Valid { + switchCount = switches.Int64 + } denom := float64(bucketSeconds) if denom <= 0 { @@ -103,6 +134,7 @@ ORDER BY bucket ASC` BucketStart: bucket.UTC(), RequestCount: requests, TokenConsumed: tokenConsumed, + SwitchCount: switchCount, QPS: qps, TPS: tps, }) @@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points [] BucketStart: cursor, RequestCount: 0, TokenConsumed: 0, + SwitchCount: 0, QPS: 0, TPS: 0, }) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 73809ee1..14e012f2 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -83,6 +83,9 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "quota": 0, + "quota_used": 0, + "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -119,6 +122,9 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "quota": 0, + "quota_used": 0, + "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -184,6 +190,7 @@ func TestAPIContracts(t *testing.T) { "sora_video_price_per_request_hd": null, "claude_code_only": false, "fallback_group_id": null, + "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -1451,6 +1458,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index dff6ba95..2f739357 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -70,7 +70,27 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查API key是否激活 if !apiKey.IsActive() { - AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") + // Provide more specific error message based on status + switch apiKey.Status { + case service.StatusAPIKeyQuotaExhausted: + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") + case service.StatusAPIKeyExpired: + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + default: + AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") + } + return + } + + // 检查API Key是否过期(即使状态是active,也要检查时间) + if apiKey.IsExpired() { + AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") + return + } + + // 检查API Key配额是否耗尽 + if apiKey.IsQuotaExhausted() { + AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") return } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 1a0b0dd5..38fbe38b 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -26,7 +26,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.") return } - apiKeyString := extractAPIKeyFromRequest(c) + apiKeyString := extractAPIKeyForGoogle(c) if apiKeyString == "" { abortWithGoogleError(c, 401, "API key is required") return @@ -108,25 +108,38 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs } } -func extractAPIKeyFromRequest(c *gin.Context) string { - authHeader := c.GetHeader("Authorization") - if authHeader != "" { - parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" { - return strings.TrimSpace(parts[1]) +// extractAPIKeyForGoogle extracts API key for Google/Gemini endpoints. +// Priority: x-goog-api-key > Authorization: Bearer > x-api-key > query key +// This allows OpenClaw and other clients using Bearer auth to work with Gemini endpoints. +func extractAPIKeyForGoogle(c *gin.Context) string { + // 1) preferred: Gemini native header + if k := strings.TrimSpace(c.GetHeader("x-goog-api-key")); k != "" { + return k + } + + // 2) fallback: Authorization: Bearer + auth := strings.TrimSpace(c.GetHeader("Authorization")) + if auth != "" { + parts := strings.SplitN(auth, " ", 2) + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + if k := strings.TrimSpace(parts[1]); k != "" { + return k + } } } - if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" { - return v - } - if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" { - return v + + // 3) x-api-key header (backward compatibility) + if k := strings.TrimSpace(c.GetHeader("x-api-key")); k != "" { + return k } + + // 4) query parameter key (for specific paths) if allowGoogleQueryKey(c.Request.URL.Path) { if v := strings.TrimSpace(c.Query("key")); v != "" { return v } } + return "" } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 6f09469b..c14582bd 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -75,6 +75,9 @@ func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]s func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { return nil, errors.New("not implemented") } +func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 920ff93f..a03f6168 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -319,6 +319,10 @@ func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ( return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + return 0, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index b1b37e11..677997c5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -116,9 +116,14 @@ type CreateGroupInput struct { SoraVideoPricePerRequestHD *float64 ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -145,9 +150,14 @@ type UpdateGroupInput struct { SoraVideoPricePerRequestHD *float64 ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) ModelRouting map[string][]int64 ModelRoutingEnabled *bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -611,6 +621,22 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn return nil, err } } + fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest + if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { + fallbackOnInvalidRequest = nil + } + // 校验无效请求兜底分组 + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + + // MCPXMLInject:默认为 true,仅当显式传入 false 时关闭 + mcpXMLInject := true + if input.MCPXMLInject != nil { + mcpXMLInject = *input.MCPXMLInject + } // 如果指定了复制账号的源分组,先获取账号 ID 列表 var accountIDsToCopy []int64 @@ -645,26 +671,29 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn } group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - SoraImagePrice360: soraImagePrice360, - SoraImagePrice540: soraImagePrice540, - SoraVideoPricePerRequest: soraVideoPrice, - SoraVideoPricePerRequestHD: soraVideoPriceHD, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - ModelRouting: input.ModelRouting, + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + SoraImagePrice540: soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, + ModelRouting: input.ModelRouting, + MCPXMLInject: mcpXMLInject, + SupportedModelScopes: input.SupportedModelScopes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -735,6 +764,37 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro } } +// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// platform/subscriptionType: 当前分组的有效平台/订阅类型 +// fallbackGroupID: 兜底分组 ID +func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { + if platform != PlatformAnthropic && platform != PlatformAntigravity { + return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") + } + if subscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("subscription groups cannot set invalid request fallback") + } + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as invalid request fallback group") + } + + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + if fallbackGroup.Platform != PlatformAnthropic { + return fmt.Errorf("fallback group must be anthropic platform") + } + if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("fallback group cannot be subscription type") + } + if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + return fmt.Errorf("fallback group cannot have invalid request fallback configured") + } + return nil +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -813,6 +873,20 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.FallbackGroupID = nil } } + fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest + if input.FallbackGroupIDOnInvalidRequest != nil { + if *input.FallbackGroupIDOnInvalidRequest > 0 { + fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest + } else { + fallbackOnInvalidRequest = nil + } + } + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest // 模型路由配置 if input.ModelRouting != nil { @@ -821,6 +895,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ModelRoutingEnabled != nil { group.ModelRoutingEnabled = *input.ModelRoutingEnabled } + if input.MCPXMLInject != nil { + group.MCPXMLInject = *input.MCPXMLInject + } + + // 支持的模型系列(仅 antigravity 平台使用) + if input.SupportedModelScopes != nil { + group.SupportedModelScopes = *input.SupportedModelScopes + } if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 1daee89f..d921a086 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -394,3 +394,382 @@ func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { panic("unexpected GetAccountIDsByGroupIDs call") } + +type groupRepoStubForInvalidRequestFallback struct { + groups map[int64]*Group + created *Group + updated *Group +} + +func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error { + s.created = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error { + s.updated = g + return nil +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) { + return s.GetByIDLite(ctx, id) +} + +func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) { + if g, ok := s.groups[id]; ok { + return g, nil + } + return nil, ErrGroupNotFound +} + +func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error { + panic("unexpected Delete call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) { + panic("unexpected DeleteCascade call") +} + +func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) { + panic("unexpected ListActive call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) { + panic("unexpected ListActiveByPlatform call") +} + +func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) { + panic("unexpected ExistsByName call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { + panic("unexpected GetAccountCount call") +} + +func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) { + panic("unexpected DeleteAccountGroupsByGroupID call") +} + +func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) { + panic("unexpected GetAccountIDsByGroupIDs call") +} + +func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error { + panic("unexpected BindAccountsToGroup call") +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformOpenAI, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeSubscription, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + tests := []struct { + name string + fallback *Group + wantMessage string + }{ + { + name: "openai_target", + fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "antigravity_target", + fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard}, + wantMessage: "fallback group must be anthropic platform", + }, + { + name: "subscription_group", + fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + wantMessage: "fallback group cannot be subscription type", + }, + { + name: "nested_fallback", + fallback: &Group{ + ID: 10, + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(), + }, + wantMessage: "fallback group cannot have invalid request fallback configured", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fallbackID := tc.fallback.ID + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: tc.fallback, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantMessage) + require.Nil(t, repo.created) + }) + } +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group not found") + require.Nil(t, repo.created) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + zero := int64(0) + repo := &groupRepoStubForInvalidRequestFallback{} + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{ + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + FallbackGroupIDOnInvalidRequest: &zero, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.created) + require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + SubscriptionType: SubscriptionTypeSubscription, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + FallbackGroupIDOnInvalidRequest: &fallbackID, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + clear := int64(0) + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + Platform: PlatformOpenAI, + FallbackGroupIDOnInvalidRequest: &clear, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "fallback group cannot be subscription type") + require.Nil(t, repo.updated) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAnthropic, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} + +func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) { + fallbackID := int64(10) + existing := &Group{ + ID: 1, + Name: "g1", + Platform: PlatformAntigravity, + SubscriptionType: SubscriptionTypeStandard, + Status: StatusActive, + } + repo := &groupRepoStubForInvalidRequestFallback{ + groups: map[int64]*Group{ + existing.ID: existing, + fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard}, + }, + } + svc := &adminServiceImpl{groupRepo: repo} + + group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{ + FallbackGroupIDOnInvalidRequest: &fallbackID, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.NotNil(t, repo.updated) + require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 9b8156e6..cf7e35fc 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -13,23 +13,34 @@ import ( "net" "net/http" "os" + "strconv" "strings" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/gin-gonic/gin" "github.com/google/uuid" ) const ( - antigravityStickySessionTTL = time.Hour - antigravityMaxRetries = 3 - antigravityRetryBaseDelay = 1 * time.Second - antigravityRetryMaxDelay = 16 * time.Second + antigravityStickySessionTTL = time.Hour + antigravityDefaultMaxRetries = 3 + antigravityRetryBaseDelay = 1 * time.Second + antigravityRetryMaxDelay = 16 * time.Second ) -const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" +const ( + antigravityMaxRetriesEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES" + antigravityMaxRetriesAfterSwitchEnv = "GATEWAY_ANTIGRAVITY_AFTER_SWITCHMAX_RETRIES" + antigravityMaxRetriesClaudeEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_CLAUDE" + antigravityMaxRetriesGeminiTextEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_TEXT" + antigravityMaxRetriesGeminiImageEnv = "GATEWAY_ANTIGRAVITY_MAX_RETRIES_GEMINI_IMAGE" + antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT" + antigravityBillingModelEnv = "GATEWAY_ANTIGRAVITY_BILL_WITH_MAPPED_MODEL" + antigravityFallbackSecondsEnv = "GATEWAY_ANTIGRAVITY_FALLBACK_COOLDOWN_SECONDS" +) // antigravityRetryLoopParams 重试循环的参数 type antigravityRetryLoopParams struct { @@ -41,6 +52,7 @@ type antigravityRetryLoopParams struct { action string body []byte quotaScope AntigravityQuotaScope + maxRetries int c *gin.Context httpUpstream HTTPUpstream settingService *SettingService @@ -52,11 +64,28 @@ type antigravityRetryLoopResult struct { resp *http.Response } +// PromptTooLongError 表示上游明确返回 prompt too long +type PromptTooLongError struct { + StatusCode int + RequestID string + Body []byte +} + +func (e *PromptTooLongError) Error() string { + return fmt.Sprintf("prompt too long: status=%d", e.StatusCode) +} + // antigravityRetryLoop 执行带 URL fallback 的重试循环 func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + baseURLs := antigravity.ForwardBaseURLs() + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLsWithBase(baseURLs) if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs + availableURLs = baseURLs + } + + maxRetries := p.maxRetries + if maxRetries <= 0 { + maxRetries = antigravityDefaultMaxRetries } var resp *http.Response @@ -76,7 +105,7 @@ func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopRe urlFallbackLoop: for urlIdx, baseURL := range availableURLs { usedBaseURL = baseURL - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + for attempt := 1; attempt <= maxRetries; attempt++ { select { case <-p.ctx.Done(): log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) @@ -109,8 +138,8 @@ urlFallbackLoop: log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) continue urlFallbackLoop } - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) + if attempt < maxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, maxRetries, err) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -134,7 +163,7 @@ urlFallbackLoop: } // 账户/模型配额限流,重试 3 次(指数退避) - if attempt < antigravityMaxRetries { + if attempt < maxRetries { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ @@ -147,7 +176,7 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, maxRetries, truncateForLog(respBody, 200)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -171,7 +200,7 @@ urlFallbackLoop: respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - if attempt < antigravityMaxRetries { + if attempt < maxRetries { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ @@ -184,7 +213,7 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, maxRetries, truncateForLog(respBody, 500)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { log.Printf("%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() @@ -390,6 +419,11 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + // 上游透传账号使用专用测试方法 + if account.Type == AccountTypeUpstream { + return s.testUpstreamConnection(ctx, account, modelID) + } + // 获取 token if s.tokenProvider == nil { return nil, errors.New("antigravity token provider not configured") @@ -484,6 +518,87 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, lastErr } +// testUpstreamConnection 测试上游透传账号连接 +func (s *AntigravityGatewayService) testUpstreamConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, errors.New("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 使用 Claude 模型进行测试 + if modelID == "" { + modelID = "claude-sonnet-4-20250514" + } + + // 构建最小测试请求 + testReq := map[string]any{ + "model": modelID, + "max_tokens": 1, + "messages": []map[string]any{ + {"role": "user", "content": "."}, + }, + } + requestBody, err := json.Marshal(testReq) + if err != nil { + return nil, fmt.Errorf("构建请求失败: %w", err) + } + + // 构建 HTTP 请求 + upstreamURL := baseURL + "/v1/messages" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(requestBody)) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + log.Printf("[antigravity-Test-Upstream] account=%s url=%s", account.Name, upstreamURL) + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return nil, fmt.Errorf("请求失败: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 提取响应文本 + var respData map[string]any + text := "" + if json.Unmarshal(respBody, &respData) == nil { + if content, ok := respData["content"].([]any); ok && len(content) > 0 { + if block, ok := content[0].(map[string]any); ok { + if t, ok := block["text"].(string); ok { + text = t + } + } + } + } + + return &TestConnectionResult{ + Text: text, + MappedModel: modelID, + }, nil +} + // buildGeminiTestRequest 构建 Gemini 格式测试请求 // 使用最小 token 消耗:输入 "." + maxOutputTokens: 1 func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) { @@ -534,6 +649,10 @@ func (s *AntigravityGatewayService) getClaudeTransformOptions(ctx context.Contex } opts.EnableIdentityPatch = s.settingService.IsIdentityPatchEnabled(ctx) opts.IdentityPatch = s.settingService.GetIdentityPatchPrompt(ctx) + + if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && group != nil { + opts.EnableMCPXML = group.MCPXMLInject + } return opts } @@ -702,6 +821,11 @@ func isModelNotFoundError(statusCode int, body []byte) bool { // Forward 转发 Claude 协议请求(Claude → Gemini 转换) func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + // 上游透传账号直接转发,不走 OAuth token 刷新 + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body) + } + startTime := time.Now() sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -718,6 +842,12 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, originalModel := claudeReq.Model mappedModel := s.getMappedModel(account, claudeReq.Model) quotaScope, _ := resolveAntigravityQuotaScope(originalModel) + billingModel := originalModel + if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { + billingModel = mappedModel + } + afterSwitch := antigravityHasAccountSwitch(ctx) + maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch) // 获取 access_token if s.tokenProvider == nil { @@ -766,6 +896,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, + maxRetries: maxRetries, }) if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") @@ -842,6 +973,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, + maxRetries: maxRetries, }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -917,6 +1049,39 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 处理错误响应(重试后仍失败或不触发重试) if resp.StatusCode >= 400 { + if resp.StatusCode == http.StatusBadRequest { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + log.Printf("%s status=400 prompt_too_long=%v upstream_message=%q request_id=%s body=%s", prefix, isPromptTooLongError(respBody), upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, 500)) + } + if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody + maxBytes := 2048 + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { + maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + } + upstreamDetail := "" + if logBody { + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "prompt_too_long", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &PromptTooLongError{ + StatusCode: resp.StatusCode, + RequestID: resp.Header.Get("x-request-id"), + Body: respBody, + } + } s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) if s.shouldFailoverUpstreamError(resp.StatusCode) { @@ -978,7 +1143,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 + Model: billingModel, // 计费模型(可按映射模型覆盖) Stream: claudeReq.Stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1003,24 +1168,64 @@ func isSignatureRelatedError(respBody []byte) bool { return true } + // Detect thinking block modification errors: + // "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" + if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + return true + } + return false } +func isPromptTooLongError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody))) + if msg == "" { + msg = strings.ToLower(string(respBody)) + } + return strings.Contains(msg, "prompt is too long") +} + func extractAntigravityErrorMessage(body []byte) string { var payload map[string]any if err := json.Unmarshal(body, &payload); err != nil { return "" } + parseNestedMessage := func(msg string) string { + trimmed := strings.TrimSpace(msg) + if trimmed == "" || !strings.HasPrefix(trimmed, "{") { + return "" + } + var nested map[string]any + if err := json.Unmarshal([]byte(trimmed), &nested); err != nil { + return "" + } + if errObj, ok := nested["error"].(map[string]any); ok { + if innerMsg, ok := errObj["message"].(string); ok && strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + if innerMsg, ok := nested["message"].(string); ok && strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + return "" + } + // Google-style: {"error": {"message": "..."}} if errObj, ok := payload["error"].(map[string]any); ok { if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" { + if innerMsg := parseNestedMessage(msg); innerMsg != "" { + return innerMsg + } return msg } } // Fallback: top-level message if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" { + if innerMsg := parseNestedMessage(msg); innerMsg != "" { + return innerMsg + } return msg } @@ -1248,6 +1453,208 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque return changed, nil } +// ForwardUpstream 透传请求到上游 Antigravity 服务 +// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // 获取上游配置 + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 解析请求获取模型信息 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + billingModel := originalModel + + // 构建上游请求 URL + upstreamURL := baseURL + "/v1/messages" + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API 兼容 + + // 透传 Claude 相关 headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + log.Printf("%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 429 错误时标记账号限流 + if resp.StatusCode == http.StatusTooManyRequests { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude) + } + + // 透传上游错误 + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: billingModel, + }, nil + } + + // 处理成功响应(流式/非流式) + var usage *ClaudeUsage + var firstTokenMs *int + + if claudeReq.Stream { + // 流式响应:透传 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime) + } else { + // 非流式响应:直接透传 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // 提取 usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // 构建计费结果 + duration := time.Since(startTime) + log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: billingModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse 透传上游流式响应并提取 usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) { + usage := &ClaudeUsage{} + var firstTokenMs *int + var firstTokenRecorded bool + + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 1024*1024) + + for scanner.Scan() { + line := scanner.Bytes() + + // 记录首 token 时间 + if !firstTokenRecorded && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + firstTokenRecorded = true + } + + // 尝试从 message_delta 或 message_stop 事件提取 usage + if bytes.HasPrefix(line, []byte("data: ")) { + dataStr := bytes.TrimPrefix(line, []byte("data: ")) + var event map[string]any + if json.Unmarshal(dataStr, &event) == nil { + if u, ok := event["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } + } + } + } + + // 透传行 + _, _ = c.Writer.Write(line) + _, _ = c.Writer.Write([]byte("\n")) + c.Writer.Flush() + } + + return usage, firstTokenMs +} + +// extractClaudeUsage 从非流式 Claude 响应提取 usage +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + } + return usage +} + // ForwardGemini 转发 Gemini 协议请求 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { startTime := time.Now() @@ -1287,6 +1694,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } mappedModel := s.getMappedModel(account, originalModel) + billingModel := originalModel + if antigravityUseMappedModelForBilling() && strings.TrimSpace(mappedModel) != "" { + billingModel = mappedModel + } + afterSwitch := antigravityHasAccountSwitch(ctx) + maxRetries := antigravityMaxRetriesForModel(originalModel, afterSwitch) // 获取 access_token if s.tokenProvider == nil { @@ -1306,8 +1719,15 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co proxyURL = account.Proxy.URL() } + // 过滤掉 parts 为空的消息(Gemini API 不接受空 parts) + filteredBody, err := filterEmptyPartsFromGeminiRequest(body) + if err != nil { + log.Printf("[Antigravity] Failed to filter empty parts: %v", err) + filteredBody = body + } + // Antigravity 上游要求必须包含身份提示词,注入到请求中 - injectedBody, err := injectIdentityPatchToGeminiRequest(body) + injectedBody, err := injectIdentityPatchToGeminiRequest(filteredBody) if err != nil { return nil, err } @@ -1344,6 +1764,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co httpUpstream: s.httpUpstream, settingService: s.settingService, handleError: s.handleUpstreamError, + maxRetries: maxRetries, }) if err != nil { return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") @@ -1493,7 +1914,7 @@ handleSuccess: return &ForwardResult{ RequestID: requestID, Usage: *usage, - Model: originalModel, + Model: billingModel, Stream: stream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -1544,6 +1965,81 @@ func antigravityUseScopeRateLimit() bool { return true } +func antigravityHasAccountSwitch(ctx context.Context) bool { + if ctx == nil { + return false + } + if v, ok := ctx.Value(ctxkey.AccountSwitchCount).(int); ok { + return v > 0 + } + return false +} + +func antigravityMaxRetries() int { + raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesEnv)) + if raw == "" { + return antigravityDefaultMaxRetries + } + value, err := strconv.Atoi(raw) + if err != nil || value <= 0 { + return antigravityDefaultMaxRetries + } + return value +} + +func antigravityMaxRetriesAfterSwitch() int { + raw := strings.TrimSpace(os.Getenv(antigravityMaxRetriesAfterSwitchEnv)) + if raw == "" { + return antigravityMaxRetries() + } + value, err := strconv.Atoi(raw) + if err != nil || value <= 0 { + return antigravityMaxRetries() + } + return value +} + +// antigravityMaxRetriesForModel 根据模型类型获取重试次数 +// 优先使用模型细分配置,未设置则回退到平台级配置 +func antigravityMaxRetriesForModel(model string, afterSwitch bool) int { + var envKey string + if strings.HasPrefix(model, "claude-") { + envKey = antigravityMaxRetriesClaudeEnv + } else if isImageGenerationModel(model) { + envKey = antigravityMaxRetriesGeminiImageEnv + } else if strings.HasPrefix(model, "gemini-") { + envKey = antigravityMaxRetriesGeminiTextEnv + } + + if envKey != "" { + if raw := strings.TrimSpace(os.Getenv(envKey)); raw != "" { + if value, err := strconv.Atoi(raw); err == nil && value > 0 { + return value + } + } + } + if afterSwitch { + return antigravityMaxRetriesAfterSwitch() + } + return antigravityMaxRetries() +} + +func antigravityUseMappedModelForBilling() bool { + v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityBillingModelEnv))) + return v == "1" || v == "true" || v == "yes" || v == "on" +} + +func antigravityFallbackCooldownSeconds() (time.Duration, bool) { + raw := strings.TrimSpace(os.Getenv(antigravityFallbackSecondsEnv)) + if raw == "" { + return 0, false + } + seconds, err := strconv.Atoi(raw) + if err != nil || seconds <= 0 { + return 0, false + } + return time.Duration(seconds) * time.Second, true +} func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { // 429 使用 Gemini 格式解析(从 body 解析重置时间) if statusCode == 429 { @@ -1556,6 +2052,9 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, pre fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes } defaultDur := time.Duration(fallbackMinutes) * time.Minute + if fallbackDur, ok := antigravityFallbackCooldownSeconds(); ok { + defaultDur = fallbackDur + } ra := time.Now().Add(defaultDur) if useScopeLimit { log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) @@ -2193,6 +2692,10 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou return fmt.Errorf("upstream error: %d message=%s", upstreamStatus, upstreamMsg) } +func (s *AntigravityGatewayService) WriteMappedClaudeError(c *gin.Context, account *Account, upstreamStatus int, upstreamRequestID string, body []byte) error { + return s.writeMappedClaudeError(c, account, upstreamStatus, upstreamRequestID, body) +} + func (s *AntigravityGatewayService) writeGoogleError(c *gin.Context, status int, message string) error { statusStr := "UNKNOWN" switch status { @@ -2618,3 +3121,55 @@ func cleanGeminiRequest(body []byte) ([]byte, error) { return json.Marshal(payload) } + +// filterEmptyPartsFromGeminiRequest 过滤 Gemini 请求中 parts 为空的消息 +// Gemini API 不接受 parts 为空数组的消息,会返回 400 错误 +func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return nil, err + } + + contents, ok := payload["contents"].([]any) + if !ok || len(contents) == 0 { + return body, nil + } + + filtered := make([]any, 0, len(contents)) + modified := false + + for _, c := range contents { + contentMap, ok := c.(map[string]any) + if !ok { + filtered = append(filtered, c) + continue + } + + parts, hasParts := contentMap["parts"] + if !hasParts { + filtered = append(filtered, c) + continue + } + + partsSlice, ok := parts.([]any) + if !ok { + filtered = append(filtered, c) + continue + } + + // 跳过 parts 为空数组的消息 + if len(partsSlice) == 0 { + modified = true + continue + } + + filtered = append(filtered, c) + } + + if !modified { + return body, nil + } + + payload["contents"] = filtered + return json.Marshal(payload) +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index 05ad9bbd..32a591ef 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -1,10 +1,16 @@ package service import ( + "bytes" + "context" "encoding/json" + "io" + "net/http" + "net/http/httptest" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -81,3 +87,106 @@ func TestStripThinkingFromClaudeRequest_DoesNotDowngradeTools(t *testing.T) { require.Equal(t, "secret plan", blocks[0]["text"]) require.Equal(t, "tool_use", blocks[1]["type"]) } + +func TestIsPromptTooLongError(t *testing.T) { + require.True(t, isPromptTooLongError([]byte(`{"error":{"message":"Prompt is too long"}}`))) + require.True(t, isPromptTooLongError([]byte(`{"message":"Prompt is too long"}`))) + require.False(t, isPromptTooLongError([]byte(`{"error":{"message":"other"}}`))) +} + +type httpUpstreamStub struct { + resp *http.Response + err error +} + +func (s *httpUpstreamStub) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return s.resp, s.err +} + +func (s *httpUpstreamStub) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + return s.resp, s.err +} + +func TestAntigravityGatewayService_Forward_PromptTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + + body, err := json.Marshal(map[string]any{ + "model": "claude-opus-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "max_tokens": 1, + "stream": false, + }) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request = req + + respBody := []byte(`{"error":{"message":"Prompt is too long"}}`) + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"X-Request-Id": []string{"req-1"}}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + svc := &AntigravityGatewayService{ + tokenProvider: &AntigravityTokenProvider{}, + httpUpstream: &httpUpstreamStub{resp: resp}, + } + + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Type: AccountTypeOAuth, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "token", + }, + } + + result, err := svc.Forward(context.Background(), c, account, body) + require.Nil(t, result) + + var promptErr *PromptTooLongError + require.ErrorAs(t, err, &promptErr) + require.Equal(t, http.StatusBadRequest, promptErr.StatusCode) + require.Equal(t, "req-1", promptErr.RequestID) + require.NotEmpty(t, promptErr.Body) + + raw, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := raw.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, "prompt_too_long", events[0].Kind) +} + +func TestAntigravityMaxRetriesForModel_AfterSwitch(t *testing.T) { + t.Setenv(antigravityMaxRetriesEnv, "4") + t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "7") + t.Setenv(antigravityMaxRetriesClaudeEnv, "") + t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") + t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") + + got := antigravityMaxRetriesForModel("claude-sonnet-4-5", false) + require.Equal(t, 4, got) + + got = antigravityMaxRetriesForModel("claude-sonnet-4-5", true) + require.Equal(t, 7, got) +} + +func TestAntigravityMaxRetriesForModel_AfterSwitchFallback(t *testing.T) { + t.Setenv(antigravityMaxRetriesEnv, "5") + t.Setenv(antigravityMaxRetriesAfterSwitchEnv, "") + t.Setenv(antigravityMaxRetriesClaudeEnv, "") + t.Setenv(antigravityMaxRetriesGeminiTextEnv, "") + t.Setenv(antigravityMaxRetriesGeminiImageEnv, "") + + got := antigravityMaxRetriesForModel("gemini-2.5-flash", true) + require.Equal(t, 5, got) +} diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index 34cd9a4c..e1a0a1f2 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -1,6 +1,7 @@ package service import ( + "slices" "strings" "time" ) @@ -16,6 +17,21 @@ const ( AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" ) +// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中 +func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool { + if len(supportedScopes) == 0 { + // 未配置时默认全部支持 + return true + } + supported := slices.Contains(supportedScopes, string(scope)) + return supported +} + +// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本) +func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { + return resolveAntigravityQuotaScope(requestedModel) +} + // resolveAntigravityQuotaScope 根据模型名称解析配额域 func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { model := normalizeAntigravityModelName(requestedModel) diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index 8c692d09..d66059dd 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -2,6 +2,14 @@ package service import "time" +// API Key status constants +const ( + StatusAPIKeyActive = "active" + StatusAPIKeyDisabled = "disabled" + StatusAPIKeyQuotaExhausted = "quota_exhausted" + StatusAPIKeyExpired = "expired" +) + type APIKey struct { ID int64 UserID int64 @@ -15,8 +23,53 @@ type APIKey struct { UpdatedAt time.Time User *User Group *Group + + // Quota fields + Quota float64 // Quota limit in USD (0 = unlimited) + QuotaUsed float64 // Used quota amount + ExpiresAt *time.Time // Expiration time (nil = never expires) } func (k *APIKey) IsActive() bool { return k.Status == StatusActive } + +// IsExpired checks if the API key has expired +func (k *APIKey) IsExpired() bool { + if k.ExpiresAt == nil { + return false + } + return time.Now().After(*k.ExpiresAt) +} + +// IsQuotaExhausted checks if the API key quota is exhausted +func (k *APIKey) IsQuotaExhausted() bool { + if k.Quota <= 0 { + return false // unlimited + } + return k.QuotaUsed >= k.Quota +} + +// GetQuotaRemaining returns remaining quota (-1 for unlimited) +func (k *APIKey) GetQuotaRemaining() float64 { + if k.Quota <= 0 { + return -1 // unlimited + } + remaining := k.Quota - k.QuotaUsed + if remaining < 0 { + return 0 + } + return remaining +} + +// GetDaysUntilExpiry returns days until expiry (-1 for never expires) +func (k *APIKey) GetDaysUntilExpiry() int { + if k.ExpiresAt == nil { + return -1 // never expires + } + duration := time.Until(*k.ExpiresAt) + if duration < 0 { + return 0 + } + return int(duration.Hours() / 24) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 9d8f87f2..4240be23 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -1,5 +1,7 @@ package service +import "time" + // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) type APIKeyAuthSnapshot struct { APIKeyID int64 `json:"api_key_id"` @@ -10,6 +12,13 @@ type APIKeyAuthSnapshot struct { IPBlacklist []string `json:"ip_blacklist,omitempty"` User APIKeyAuthUserSnapshot `json:"user"` Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` + + // Quota fields for API Key independent quota feature + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + QuotaUsed float64 `json:"quota_used"` // Used quota amount + + // Expiration field for API Key expiration feature + ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires) } // APIKeyAuthUserSnapshot 用户快照 @@ -23,29 +32,34 @@ type APIKeyAuthUserSnapshot struct { // APIKeyAuthGroupSnapshot 分组快照 type APIKeyAuthGroupSnapshot struct { - ID int64 `json:"id"` - Name string `json:"name"` - Platform string `json:"platform"` - Status string `json:"status"` - SubscriptionType string `json:"subscription_type"` - RateMultiplier float64 `json:"rate_multiplier"` - DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` - WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` - MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` - ImagePrice1K *float64 `json:"image_price_1k,omitempty"` - ImagePrice2K *float64 `json:"image_price_2k,omitempty"` - ImagePrice4K *float64 `json:"image_price_4k,omitempty"` - SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` - SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` - SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` - SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` - ClaudeCodeOnly bool `json:"claude_code_only"` - FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` + FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` // Model routing is used by gateway account selection, so it must be part of auth cache snapshot. // Only anthropic groups use these fields; others may leave them empty. ModelRouting map[string][]int64 `json:"model_routing,omitempty"` ModelRoutingEnabled bool `json:"model_routing_enabled"` + MCPXMLInject bool `json:"mcp_xml_inject"` + + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 19ba4e79..f266a12b 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -213,6 +213,9 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { Status: apiKey.Status, IPWhitelist: apiKey.IPWhitelist, IPBlacklist: apiKey.IPBlacklist, + Quota: apiKey.Quota, + QuotaUsed: apiKey.QuotaUsed, + ExpiresAt: apiKey.ExpiresAt, User: APIKeyAuthUserSnapshot{ ID: apiKey.User.ID, Status: apiKey.User.Status, @@ -223,26 +226,29 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ - ID: apiKey.Group.ID, - Name: apiKey.Group.Name, - Platform: apiKey.Group.Platform, - Status: apiKey.Group.Status, - SubscriptionType: apiKey.Group.SubscriptionType, - RateMultiplier: apiKey.Group.RateMultiplier, - DailyLimitUSD: apiKey.Group.DailyLimitUSD, - WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, - MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, - ImagePrice1K: apiKey.Group.ImagePrice1K, - ImagePrice2K: apiKey.Group.ImagePrice2K, - ImagePrice4K: apiKey.Group.ImagePrice4K, - SoraImagePrice360: apiKey.Group.SoraImagePrice360, - SoraImagePrice540: apiKey.Group.SoraImagePrice540, - SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, - FallbackGroupID: apiKey.Group.FallbackGroupID, - ModelRouting: apiKey.Group.ModelRouting, - ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: apiKey.Group.ModelRouting, + ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled, + MCPXMLInject: apiKey.Group.MCPXMLInject, + SupportedModelScopes: apiKey.Group.SupportedModelScopes, } } return snapshot @@ -260,6 +266,9 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho Status: snapshot.Status, IPWhitelist: snapshot.IPWhitelist, IPBlacklist: snapshot.IPBlacklist, + Quota: snapshot.Quota, + QuotaUsed: snapshot.QuotaUsed, + ExpiresAt: snapshot.ExpiresAt, User: &User{ ID: snapshot.User.ID, Status: snapshot.User.Status, @@ -270,27 +279,30 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho } if snapshot.Group != nil { apiKey.Group = &Group{ - ID: snapshot.Group.ID, - Name: snapshot.Group.Name, - Platform: snapshot.Group.Platform, - Status: snapshot.Group.Status, - Hydrated: true, - SubscriptionType: snapshot.Group.SubscriptionType, - RateMultiplier: snapshot.Group.RateMultiplier, - DailyLimitUSD: snapshot.Group.DailyLimitUSD, - WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, - MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, - ImagePrice1K: snapshot.Group.ImagePrice1K, - ImagePrice2K: snapshot.Group.ImagePrice2K, - ImagePrice4K: snapshot.Group.ImagePrice4K, - SoraImagePrice360: snapshot.Group.SoraImagePrice360, - SoraImagePrice540: snapshot.Group.SoraImagePrice540, - SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, - SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, - ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, - FallbackGroupID: snapshot.Group.FallbackGroupID, - ModelRouting: snapshot.Group.ModelRouting, - ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, + ModelRouting: snapshot.Group.ModelRouting, + ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled, + MCPXMLInject: snapshot.Group.MCPXMLInject, + SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } return apiKey diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index ef1ff990..b27682f3 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -24,6 +24,10 @@ var ( ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens") ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later") ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern") + // ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired") + ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期") + // ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted") + ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完") ) const ( @@ -51,6 +55,9 @@ type APIKeyRepository interface { CountByGroupID(ctx context.Context, groupID int64) (int64, error) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) + + // Quota methods + IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) } // APIKeyCache defines cache operations for API key service @@ -85,6 +92,10 @@ type CreateAPIKeyRequest struct { CustomKey *string `json:"custom_key"` // 可选的自定义key IPWhitelist []string `json:"ip_whitelist"` // IP 白名单 IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单 + + // Quota fields + Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) + ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires) } // UpdateAPIKeyRequest 更新API Key请求 @@ -94,6 +105,12 @@ type UpdateAPIKeyRequest struct { Status *string `json:"status"` IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空) IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空) + + // Quota fields + Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited) + ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change) + ClearExpiration bool `json:"-"` // Clear expiration (internal use) + ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0 } // APIKeyService API Key服务 @@ -289,6 +306,14 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK Status: StatusActive, IPWhitelist: req.IPWhitelist, IPBlacklist: req.IPBlacklist, + Quota: req.Quota, + QuotaUsed: 0, + } + + // Set expiration time if specified + if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 { + expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays) + apiKey.ExpiresAt = &expiresAt } if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil { @@ -436,6 +461,35 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } } + // Update quota fields + if req.Quota != nil { + apiKey.Quota = *req.Quota + // If quota is increased and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed { + apiKey.Status = StatusActive + } + } + if req.ResetQuota != nil && *req.ResetQuota { + apiKey.QuotaUsed = 0 + // If resetting quota and status was quota_exhausted, reactivate + if apiKey.Status == StatusAPIKeyQuotaExhausted { + apiKey.Status = StatusActive + } + } + if req.ClearExpiration { + apiKey.ExpiresAt = nil + // If clearing expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired { + apiKey.Status = StatusActive + } + } else if req.ExpiresAt != nil { + apiKey.ExpiresAt = req.ExpiresAt + // If extending expiry and status was expired, reactivate + if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) { + apiKey.Status = StatusActive + } + } + // 更新 IP 限制(空数组会清空设置) apiKey.IPWhitelist = req.IPWhitelist apiKey.IPBlacklist = req.IPBlacklist @@ -572,3 +626,51 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword } return keys, nil } + +// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) +// Returns nil if valid, error if invalid +func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { + // Check expiration + if apiKey.IsExpired() { + return ErrAPIKeyExpired + } + + // Check quota + if apiKey.IsQuotaExhausted() { + return ErrAPIKeyQuotaExhausted + } + + return nil +} + +// UpdateQuotaUsed updates the quota_used field after a request +// Also checks if quota is exhausted and updates status accordingly +func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + if cost <= 0 { + return nil + } + + // Use repository to atomically increment quota_used + newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost) + if err != nil { + return fmt.Errorf("increment quota used: %w", err) + } + + // Check if quota is now exhausted and update status if needed + apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID) + if err != nil { + return nil // Don't fail the request, just log + } + + // If quota is set and now exhausted, update status + if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota { + apiKey.Status = StatusAPIKeyQuotaExhausted + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil // Don't fail the request + } + // Invalidate cache so next request sees the new status + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + return nil +} diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index c5e9cd47..1099b1d2 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -99,6 +99,10 @@ func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([] return s.listKeysByGroupID(ctx, groupID) } +func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} + type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) setAuthKeys []string diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 092b7fce..d4d12144 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -118,6 +118,10 @@ func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ( panic("unexpected ListKeysByGroupID call") } +func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { + panic("unexpected IncrementQuotaUsed call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index c824ec1e..25604d2c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -185,7 +185,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } - // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 0f4f2be0..8a4f69b8 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -32,6 +32,7 @@ const ( AccountTypeOAuth = domain.AccountTypeOAuth // OAuth类型账号(full scope: profile + inference) AccountTypeSetupToken = domain.AccountTypeSetupToken // Setup Token类型账号(inference only scope) AccountTypeAPIKey = domain.AccountTypeAPIKey // API Key类型账号 + AccountTypeUpstream = domain.AccountTypeUpstream // 上游透传类型账号(通过 Base URL + API Key 连接上游) ) // Redeem type constants diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5058f265..e188668d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -257,6 +257,9 @@ var ( // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") +// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 +var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -589,12 +592,18 @@ func (s *GatewayService) hashContent(content string) string { } // replaceModelInBody 替换请求体中的model字段 +// 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { - var req map[string]any + var req map[string]json.RawMessage if err := json.Unmarshal(body, &req); err != nil { return body } - req["model"] = newModel + // 只序列化 model 字段 + modelBytes, err := json.Marshal(newModel) + if err != nil { + return body + } + req["model"] = modelBytes newBody, err := json.Marshal(req) if err != nil { return body @@ -791,12 +800,21 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu if len(body) == 0 { return body, modelID, nil } + + // 使用 json.RawMessage 保留 messages 的原始字节,避免 thinking 块被修改 + var reqRaw map[string]json.RawMessage + if err := json.Unmarshal(body, &reqRaw); err != nil { + return body, modelID, nil + } + + // 同时解析为 map[string]any 用于修改非 messages 字段 var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return body, modelID, nil } toolNameMap := make(map[string]string) + modified := false if system, ok := req["system"]; ok { switch v := system.(type) { @@ -804,6 +822,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu sanitized := sanitizeSystemText(v) if sanitized != v { req["system"] = sanitized + modified = true } case []any: for _, item := range v { @@ -821,6 +840,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu sanitized := sanitizeSystemText(text) if sanitized != text { block["text"] = sanitized + modified = true } } } @@ -831,6 +851,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu if normalized != rawModel { req["model"] = normalized modelID = normalized + modified = true } } @@ -846,16 +867,19 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu normalized := normalizeToolNameForClaude(name, toolNameMap) if normalized != "" && normalized != name { toolMap["name"] = normalized + modified = true } } if desc, ok := toolMap["description"].(string); ok { sanitized := sanitizeToolDescription(desc) if sanitized != desc { toolMap["description"] = sanitized + modified = true } } if schema, ok := toolMap["input_schema"]; ok { normalizeToolInputSchema(schema, toolNameMap) + modified = true } tools[idx] = toolMap } @@ -884,11 +908,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu normalizedTools[normalized] = value } req["tools"] = normalizedTools + modified = true } } else { req["tools"] = []any{} + modified = true } + // 处理 messages 中的 tool_use 块,但保留包含 thinking 块的消息的原始字节 + messagesModified := false if messages, ok := req["messages"].([]any); ok { for _, msg := range messages { msgMap, ok := msg.(map[string]any) @@ -899,6 +927,24 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu if !ok { continue } + // 检查此消息是否包含 thinking 块 + hasThinking := false + for _, block := range content { + blockMap, ok := block.(map[string]any) + if !ok { + continue + } + blockType, _ := blockMap["type"].(string) + if blockType == "thinking" || blockType == "redacted_thinking" { + hasThinking = true + break + } + } + // 如果包含 thinking 块,跳过此消息的修改 + if hasThinking { + continue + } + // 只修改不包含 thinking 块的消息中的 tool_use for _, block := range content { blockMap, ok := block.(map[string]any) if !ok { @@ -911,6 +957,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu normalized := normalizeToolNameForClaude(name, toolNameMap) if normalized != "" && normalized != name { blockMap["name"] = normalized + messagesModified = true } } } @@ -920,6 +967,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu if opts.stripSystemCacheControl { if system, ok := req["system"]; ok { _ = stripCacheControlFromSystemBlocks(system) + modified = true } } @@ -931,12 +979,46 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu } if existing, ok := metadata["user_id"].(string); !ok || existing == "" { metadata["user_id"] = opts.metadataUserID + modified = true } } - delete(req, "temperature") - delete(req, "tool_choice") + if _, hasTemp := req["temperature"]; hasTemp { + delete(req, "temperature") + modified = true + } + if _, hasChoice := req["tool_choice"]; hasChoice { + delete(req, "tool_choice") + modified = true + } + if !modified && !messagesModified { + return body, modelID, toolNameMap + } + + // 如果 messages 没有被修改,保留原始 messages 字节 + if !messagesModified { + // 序列化非 messages 字段 + newBody, err := json.Marshal(req) + if err != nil { + return body, modelID, toolNameMap + } + // 替换回原始的 messages + var newReq map[string]json.RawMessage + if err := json.Unmarshal(newBody, &newReq); err != nil { + return newBody, modelID, toolNameMap + } + if origMessages, ok := reqRaw["messages"]; ok { + newReq["messages"] = origMessages + } + finalBody, err := json.Marshal(newReq) + if err != nil { + return newBody, modelID, toolNameMap + } + return finalBody, modelID, toolNameMap + } + + // messages 被修改了,需要完整序列化 newBody, err := json.Marshal(req) if err != nil { return body, modelID, toolNameMap @@ -1139,6 +1221,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } + // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) + if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { + if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { + return nil, err + } + } + accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, err @@ -1636,6 +1725,10 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (* return group, nil } +func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + return s.resolveGroupByID(ctx, groupID) +} + func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { return nil @@ -1701,7 +1794,7 @@ func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID } // 强制平台模式不检查 Claude Code 限制 - if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { + if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { return nil, groupID, nil } @@ -2030,6 +2123,13 @@ func shuffleWithinPriority(accounts []*Account) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 + if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { + if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { + return nil, err + } + } + preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) @@ -2465,6 +2565,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo // Antigravity 平台使用专门的模型支持检查 return IsAntigravityModelSupported(requestedModel) } + // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) + if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + requestedModel = claude.NormalizeModelID(requestedModel) + } // Gemini API Key 账户直接透传,由上游判断模型是否支持 if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { return true @@ -2914,16 +3018,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) - // 应用模型映射(仅对apikey类型账号) + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + mappedModel := reqModel + mappingSource := "" if account.Type == AccountTypeAPIKey { - mappedModel := account.GetMappedModel(reqModel) + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - // 替换请求体中的模型名 - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + // 替换请求体中的模型名 + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + } // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) @@ -3625,6 +3743,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return true } + // 检测 thinking block 被修改的错误 + // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" + if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + log.Printf("[SignatureCheck] Detected thinking block modification error") + return true + } + // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 例如: "all messages must have non-empty content" if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { @@ -4493,13 +4618,19 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 + Result *ForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 +} + +// APIKeyQuotaUpdater defines the interface for updating API Key quota +type APIKeyQuotaUpdater interface { + UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error } // RecordUsage 记录使用量并扣费(或更新订阅用量) @@ -4661,6 +4792,13 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } } + // 更新 API Key 配额(如果设置了配额限制) + if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Update API key quota failed: %v", err) + } + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) @@ -4678,6 +4816,7 @@ type RecordUsageLongContextInput struct { IPAddress string // 请求的客户端 IP 地址 LongContextThreshold int // 长上下文阈值(如 200000) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) + APIKeyService *APIKeyService // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) @@ -4814,6 +4953,12 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + // API Key 独立配额扣费 + if input.APIKeyService != nil && apiKey.Quota > 0 { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Add API key quota used failed: %v", err) + } + } } } @@ -4848,16 +4993,30 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } - // 应用模型映射(仅对 apikey 类型账号) - if account.Type == AccountTypeAPIKey { - if reqModel != "" { - mappedModel := account.GetMappedModel(reqModel) + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + if reqModel != "" { + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + mappingSource = "account" } } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + } } // 获取凭证 @@ -5109,6 +5268,27 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } +// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 +func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { + scope, ok := ResolveAntigravityQuotaScope(requestedModel) + if !ok { + return nil // 无法解析 scope,跳过检查 + } + + group, err := s.resolveGroupByID(ctx, groupID) + if err != nil { + return nil // 查询失败时放行 + } + if group == nil { + return nil // 分组不存在时放行 + } + + if !IsScopeSupported(group.SupportedModelScopes, scope) { + return ErrModelScopeNotSupported + } + return nil +} + // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 2d2e86d5..bd322991 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -977,6 +977,11 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } + // 过滤掉 parts 为空的消息(Gemini API 不接受空 parts) + if filteredBody, err := filterEmptyPartsFromGeminiRequest(body); err == nil { + body = filteredBody + } + switch action { case "generateContent", "streamGenerateContent", "countTokens": // ok diff --git a/backend/internal/service/gemini_native_signature_cleaner.go b/backend/internal/service/gemini_native_signature_cleaner.go index b3352fb0..d43fb445 100644 --- a/backend/internal/service/gemini_native_signature_cleaner.go +++ b/backend/internal/service/gemini_native_signature_cleaner.go @@ -2,20 +2,22 @@ package service import ( "encoding/json" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) -// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段, +// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中替换 thoughtSignature 字段为 dummy 签名, // 以避免跨账号签名验证错误。 // // 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature -// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。 +// 会导致新账号的签名验证失败。通过替换为 dummy 签名,跳过签名验证。 // -// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests -// to avoid cross-account signature validation errors. +// CleanGeminiNativeThoughtSignatures replaces thoughtSignature fields with dummy signature +// in Gemini native API requests to avoid cross-account signature validation errors. // // When sticky session switches accounts (e.g., original account becomes unavailable), // thoughtSignatures from the old account will cause validation failures on the new account. -// By removing these signatures, we allow the new account to generate valid signatures. +// By replacing with dummy signature, we skip signature validation. func CleanGeminiNativeThoughtSignatures(body []byte) []byte { if len(body) == 0 { return body @@ -28,11 +30,11 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte { return body } - // 递归清理 thoughtSignature - cleaned := cleanThoughtSignaturesRecursive(data) + // 递归替换 thoughtSignature 为 dummy 签名 + replaced := replaceThoughtSignaturesRecursive(data) // 重新序列化 - result, err := json.Marshal(cleaned) + result, err := json.Marshal(replaced) if err != nil { // 如果序列化失败,返回原始 body return body @@ -41,19 +43,20 @@ func CleanGeminiNativeThoughtSignatures(body []byte) []byte { return result } -// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段 -func cleanThoughtSignaturesRecursive(data any) any { +// replaceThoughtSignaturesRecursive 递归遍历数据结构,将所有 thoughtSignature 字段替换为 dummy 签名 +func replaceThoughtSignaturesRecursive(data any) any { switch v := data.(type) { case map[string]any: - // 创建新的 map,移除 thoughtSignature + // 创建新的 map,替换 thoughtSignature 为 dummy 签名 result := make(map[string]any, len(v)) for key, value := range v { - // 跳过 thoughtSignature 字段 + // 替换 thoughtSignature 字段为 dummy 签名 if key == "thoughtSignature" { + result[key] = antigravity.DummyThoughtSignature continue } // 递归处理嵌套结构 - result[key] = cleanThoughtSignaturesRecursive(value) + result[key] = replaceThoughtSignaturesRecursive(value) } return result @@ -61,7 +64,7 @@ func cleanThoughtSignaturesRecursive(data any) any { // 递归处理数组中的每个元素 result := make([]any, len(v)) for i, item := range v { - result[i] = cleanThoughtSignaturesRecursive(item) + result[i] = replaceThoughtSignaturesRecursive(item) } return result diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e8bf03d4..23880b0b 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -35,6 +35,8 @@ type Group struct { // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 + // 无效请求兜底分组(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置 // key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*") @@ -42,6 +44,13 @@ type Group struct { ModelRouting map[string][]int64 ModelRoutingEnabled bool + // MCP XML 协议注入开关(仅 antigravity 平台使用) + MCPXMLInject bool + + // 支持的模型系列(仅 antigravity 平台使用) + // 可选值: claude, gemini_text, gemini_image + SupportedModelScopes []string + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index a620ac4d..261da0ef 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -169,22 +169,31 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { // RewriteUserID 重写body中的metadata.user_id // 输入格式:user_{clientId}_account__session_{sessionUUID} // 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} +// +// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, +// 避免重新序列化导致 thinking 块等内容被修改。 func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { if len(body) == 0 || accountUUID == "" || cachedClientID == "" { return body, nil } - // 解析JSON - var reqMap map[string]any + // 使用 RawMessage 保留其他字段的原始字节 + var reqMap map[string]json.RawMessage if err := json.Unmarshal(body, &reqMap); err != nil { return body, nil } - metadata, ok := reqMap["metadata"].(map[string]any) + // 解析 metadata 字段 + metadataRaw, ok := reqMap["metadata"] if !ok { return body, nil } + var metadata map[string]any + if err := json.Unmarshal(metadataRaw, &metadata); err != nil { + return body, nil + } + userID, ok := metadata["user_id"].(string) if !ok || userID == "" { return body, nil @@ -207,7 +216,13 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) metadata["user_id"] = newUserID - reqMap["metadata"] = metadata + + // 只重新序列化 metadata 字段 + newMetadataRaw, err := json.Marshal(metadata) + if err != nil { + return body, nil + } + reqMap["metadata"] = newMetadataRaw return json.Marshal(reqMap) } @@ -215,6 +230,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI // RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 // 如果账号启用了会话ID伪装(session_id_masking_enabled), // 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变) +// +// 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, +// 避免重新序列化导致 thinking 块等内容被修改。 func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { // 先执行常规的 RewriteUserID 逻辑 newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) @@ -227,17 +245,23 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b return newBody, nil } - // 解析重写后的 body,提取 user_id - var reqMap map[string]any + // 使用 RawMessage 保留其他字段的原始字节 + var reqMap map[string]json.RawMessage if err := json.Unmarshal(newBody, &reqMap); err != nil { return newBody, nil } - metadata, ok := reqMap["metadata"].(map[string]any) + // 解析 metadata 字段 + metadataRaw, ok := reqMap["metadata"] if !ok { return newBody, nil } + var metadata map[string]any + if err := json.Unmarshal(metadataRaw, &metadata); err != nil { + return newBody, nil + } + userID, ok := metadata["user_id"].(string) if !ok || userID == "" { return newBody, nil @@ -278,7 +302,13 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b ) metadata["user_id"] = newUserID - reqMap["metadata"] = metadata + + // 只重新序列化 metadata 字段 + newMetadataRaw, marshalErr := json.Marshal(metadata) + if marshalErr != nil { + return newBody, nil + } + reqMap["metadata"] = newMetadataRaw return json.Marshal(reqMap) } diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 48c72593..6460558e 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -72,7 +72,7 @@ type opencodeCacheMetadata struct { LastChecked int64 `json:"lastChecked"` } -func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { +func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { result := codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 needsToolContinuation := NeedsToolContinuation(reqBody) @@ -118,22 +118,9 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result.PromptCacheKey = strings.TrimSpace(v) } - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) - existingInstructions, _ := reqBody["instructions"].(string) - existingInstructions = strings.TrimSpace(existingInstructions) - - if instructions != "" { - if existingInstructions != instructions { - reqBody["instructions"] = instructions - result.Modified = true - } - } else if existingInstructions == "" { - // 未获取到 opencode 指令时,回退使用 Codex CLI 指令。 - codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) - if codexInstructions != "" { - reqBody["instructions"] = codexInstructions - result.Modified = true - } + // instructions 处理逻辑:根据是否是 Codex CLI 分别调用不同方法 + if applyInstructions(reqBody, isCodexCLI) { + result.Modified = true } // 续链场景保留 item_reference 与 id,避免 call_id 上下文丢失。 @@ -276,6 +263,72 @@ func GetCodexCLIInstructions() string { return getCodexCLIInstructions() } +// applyInstructions 处理 instructions 字段 +// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) +// isCodexCLI=false: 优先使用 opencode 指令覆盖 +func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { + if isCodexCLI { + return applyCodexCLIInstructions(reqBody) + } + return applyOpenCodeInstructions(reqBody) +} + +// applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions +// 仅在 instructions 为空时添加 opencode 指令 +func applyCodexCLIInstructions(reqBody map[string]any) bool { + if !isInstructionsEmpty(reqBody) { + return false // 已有有效 instructions,不修改 + } + + instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + if instructions != "" { + reqBody["instructions"] = instructions + return true + } + + return false +} + +// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 +// 优先使用 opencode 指令覆盖 +func applyOpenCodeInstructions(reqBody map[string]any) bool { + instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + existingInstructions, _ := reqBody["instructions"].(string) + existingInstructions = strings.TrimSpace(existingInstructions) + + if instructions != "" { + if existingInstructions != instructions { + reqBody["instructions"] = instructions + return true + } + } else if existingInstructions == "" { + codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) + if codexInstructions != "" { + reqBody["instructions"] = codexInstructions + return true + } + } + + return false +} + +// isInstructionsEmpty 检查 instructions 字段是否为空 +// 处理以下情况:字段不存在、nil、空字符串、纯空白字符串 +func isInstructionsEmpty(reqBody map[string]any) bool { + val, exists := reqBody["instructions"] + if !exists { + return true + } + if val == nil { + return true + } + str, ok := val.(string) + if !ok { + return true + } + return strings.TrimSpace(str) == "" +} + // ReplaceWithCodexInstructions 将请求 instructions 替换为内置 Codex 指令(必要时)。 func ReplaceWithCodexInstructions(reqBody map[string]any) bool { codexInstructions := strings.TrimSpace(getCodexCLIInstructions()) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 4cd72ab6..ac384553 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -23,7 +23,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) // 未显式设置 store=true,默认为 false。 store, ok := reqBody["store"].(bool) @@ -59,7 +59,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -79,7 +79,7 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { "tool_choice": "auto", } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -97,7 +97,7 @@ func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs( }, } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) store, ok := reqBody["store"].(bool) require.True(t, ok) @@ -148,7 +148,7 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction }, } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) tools, ok := reqBody["tools"].([]any) require.True(t, ok) @@ -169,7 +169,7 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { "input": []any{}, } - applyCodexOAuthTransform(reqBody) + applyCodexOAuthTransform(reqBody, false) input, ok := reqBody["input"].([]any) require.True(t, ok) @@ -196,3 +196,77 @@ func setupCodexCache(t *testing.T) { require.NoError(t, err) require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) } + +func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { + // Codex CLI 场景:已有 instructions 时不修改 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "instructions": "existing instructions", + } + + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.Equal(t, "existing instructions", instructions) + // Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变 + _ = result +} + +func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { + // Codex CLI 场景:无 instructions 时补充默认值 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + // 没有 instructions 字段 + } + + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotEmpty(t, instructions) + require.True(t, result.Modified) +} + +func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { + // 非 Codex CLI 场景:使用 opencode 指令覆盖 + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "instructions": "old instructions", + } + + result := applyCodexOAuthTransform(reqBody, false) // isCodexCLI=false + + instructions, ok := reqBody["instructions"].(string) + require.True(t, ok) + require.NotEqual(t, "old instructions", instructions) + require.True(t, result.Modified) +} + +func TestIsInstructionsEmpty(t *testing.T) { + tests := []struct { + name string + reqBody map[string]any + expected bool + }{ + {"missing field", map[string]any{}, true}, + {"nil value", map[string]any{"instructions": nil}, true}, + {"empty string", map[string]any{"instructions": ""}, true}, + {"whitespace only", map[string]any{"instructions": " "}, true}, + {"non-string", map[string]any{"instructions": 123}, true}, + {"valid string", map[string]any{"instructions": "hello"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isInstructionsEmpty(tt.reqBody) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6d93e92d..742946d8 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -796,8 +796,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } - if account.Type == AccountTypeOAuth && !isCodexCLI { - codexResult := applyCodexOAuthTransform(reqBody) + if account.Type == AccountTypeOAuth { + codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true } @@ -1681,13 +1681,14 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + APIKeyService APIKeyQuotaUpdater } // RecordUsage records usage and deducts balance @@ -1799,6 +1800,13 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } } + // Update API key quota if applicable (only for balance mode with quota set) + if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { + if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { + log.Printf("Update API key quota failed: %v", err) + } + } + // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) diff --git a/backend/internal/service/ops_metrics_collector.go b/backend/internal/service/ops_metrics_collector.go index edf32cf2..30adaae0 100644 --- a/backend/internal/service/ops_metrics_collector.go +++ b/backend/internal/service/ops_metrics_collector.go @@ -285,6 +285,11 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error { return fmt.Errorf("query error counts: %w", err) } + accountSwitchCount, err := c.queryAccountSwitchCount(ctx, windowStart, windowEnd) + if err != nil { + return fmt.Errorf("query account switch counts: %w", err) + } + windowSeconds := windowEnd.Sub(windowStart).Seconds() if windowSeconds <= 0 { windowSeconds = 60 @@ -309,9 +314,10 @@ func (c *OpsMetricsCollector) collectAndPersist(ctx context.Context) error { Upstream429Count: upstream429, Upstream529Count: upstream529, - TokenConsumed: tokenConsumed, - QPS: float64Ptr(roundTo1DP(qps)), - TPS: float64Ptr(roundTo1DP(tps)), + TokenConsumed: tokenConsumed, + AccountSwitchCount: accountSwitchCount, + QPS: float64Ptr(roundTo1DP(qps)), + TPS: float64Ptr(roundTo1DP(tps)), DurationP50Ms: duration.p50, DurationP90Ms: duration.p90, @@ -551,6 +557,27 @@ WHERE created_at >= $1 AND created_at < $2` return errorTotal, businessLimited, errorSLA, upstreamExcl429529, upstream429, upstream529, nil } +func (c *OpsMetricsCollector) queryAccountSwitchCount(ctx context.Context, start, end time.Time) (int64, error) { + q := ` +SELECT + COALESCE(SUM(CASE + WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1 + ELSE 0 + END), 0) AS switch_count +FROM ops_error_logs o +CROSS JOIN LATERAL jsonb_array_elements( + COALESCE(NULLIF(o.upstream_errors, 'null'::jsonb), '[]'::jsonb) +) AS ev +WHERE o.created_at >= $1 AND o.created_at < $2 + AND o.is_count_tokens = FALSE` + + var count int64 + if err := c.db.QueryRowContext(ctx, q, start, end).Scan(&count); err != nil { + return 0, err + } + return count, nil +} + type opsCollectedSystemStats struct { cpuUsagePercent *float64 memoryUsedMB *int64 diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 515b47bb..347b06b5 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -161,7 +161,8 @@ type OpsInsertSystemMetricsInput struct { Upstream429Count int64 Upstream529Count int64 - TokenConsumed int64 + TokenConsumed int64 + AccountSwitchCount int64 QPS *float64 TPS *float64 @@ -223,8 +224,9 @@ type OpsSystemMetricsSnapshot struct { DBConnIdle *int `json:"db_conn_idle"` DBConnWaiting *int `json:"db_conn_waiting"` - GoroutineCount *int `json:"goroutine_count"` - ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"` + GoroutineCount *int `json:"goroutine_count"` + ConcurrencyQueueDepth *int `json:"concurrency_queue_depth"` + AccountSwitchCount *int64 `json:"account_switch_count"` } type OpsUpsertJobHeartbeatInput struct { diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 8d98e43f..ffe4c934 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -476,9 +477,13 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq continue } + attemptCtx := ctx + if switches > 0 { + attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() - return s.executeWithAccount(ctx, reqType, errorLog, body, account) + return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account) }() if exec != nil { diff --git a/backend/internal/service/ops_trend_models.go b/backend/internal/service/ops_trend_models.go index f6d07c14..97bbfebe 100644 --- a/backend/internal/service/ops_trend_models.go +++ b/backend/internal/service/ops_trend_models.go @@ -6,6 +6,7 @@ type OpsThroughputTrendPoint struct { BucketStart time.Time `json:"bucket_start"` RequestCount int64 `json:"request_count"` TokenConsumed int64 `json:"token_consumed"` + SwitchCount int64 `json:"switch_count"` QPS float64 `json:"qps"` TPS float64 `json:"tps"` } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 99bf7fd0..1bfb392e 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -39,7 +39,7 @@ type UserRepository interface { ExistsByEmail(ctx context.Context, email string) (bool, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) - // TOTP 相关方法 + // TOTP 双因素认证 UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error EnableTotp(ctx context.Context, userID int64) error DisableTotp(ctx context.Context, userID int64) error diff --git a/backend/migrations/042b_add_ops_system_metrics_switch_count.sql b/backend/migrations/042b_add_ops_system_metrics_switch_count.sql new file mode 100644 index 00000000..6d9f48e5 --- /dev/null +++ b/backend/migrations/042b_add_ops_system_metrics_switch_count.sql @@ -0,0 +1,3 @@ +-- ops_system_metrics 增加账号切换次数统计(按分钟窗口) +ALTER TABLE ops_system_metrics + ADD COLUMN IF NOT EXISTS account_switch_count BIGINT NOT NULL DEFAULT 0; diff --git a/backend/migrations/043b_add_group_invalid_request_fallback.sql b/backend/migrations/043b_add_group_invalid_request_fallback.sql new file mode 100644 index 00000000..1c792704 --- /dev/null +++ b/backend/migrations/043b_add_group_invalid_request_fallback.sql @@ -0,0 +1,13 @@ +-- 043_add_group_invalid_request_fallback.sql +-- 添加无效请求兜底分组配置 + +-- 添加 fallback_group_id_on_invalid_request 字段:无效请求兜底使用的分组 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS fallback_group_id_on_invalid_request BIGINT REFERENCES groups(id) ON DELETE SET NULL; + +-- 添加索引优化查询 +CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id_on_invalid_request +ON groups(fallback_group_id_on_invalid_request) WHERE deleted_at IS NULL AND fallback_group_id_on_invalid_request IS NOT NULL; + +-- 添加字段注释 +COMMENT ON COLUMN groups.fallback_group_id_on_invalid_request IS '无效请求兜底使用的分组 ID'; diff --git a/backend/migrations/044b_add_group_mcp_xml_inject.sql b/backend/migrations/044b_add_group_mcp_xml_inject.sql new file mode 100644 index 00000000..7db71dd8 --- /dev/null +++ b/backend/migrations/044b_add_group_mcp_xml_inject.sql @@ -0,0 +1,2 @@ +-- Add mcp_xml_inject field to groups table (for antigravity platform) +ALTER TABLE groups ADD COLUMN mcp_xml_inject BOOLEAN NOT NULL DEFAULT true; diff --git a/backend/migrations/045_add_api_key_quota.sql b/backend/migrations/045_add_api_key_quota.sql new file mode 100644 index 00000000..b3c42d2c --- /dev/null +++ b/backend/migrations/045_add_api_key_quota.sql @@ -0,0 +1,20 @@ +-- Migration: Add quota fields to api_keys table +-- This migration adds independent quota and expiration support for API keys + +-- Add quota limit field (0 = unlimited) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS quota DECIMAL(20, 8) NOT NULL DEFAULT 0; + +-- Add used quota amount field +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS quota_used DECIMAL(20, 8) NOT NULL DEFAULT 0; + +-- Add expiration time field (NULL = never expires) +ALTER TABLE api_keys ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ; + +-- Add indexes for efficient quota queries +CREATE INDEX IF NOT EXISTS idx_api_keys_quota_quota_used ON api_keys(quota, quota_used) WHERE deleted_at IS NULL; +CREATE INDEX IF NOT EXISTS idx_api_keys_expires_at ON api_keys(expires_at) WHERE deleted_at IS NULL; + +-- Comment on columns for documentation +COMMENT ON COLUMN api_keys.quota IS 'Quota limit in USD for this API key (0 = unlimited)'; +COMMENT ON COLUMN api_keys.quota_used IS 'Used quota amount in USD'; +COMMENT ON COLUMN api_keys.expires_at IS 'Expiration time for this API key (null = never expires)'; diff --git a/backend/migrations/046b_add_group_supported_model_scopes.sql b/backend/migrations/046b_add_group_supported_model_scopes.sql new file mode 100644 index 00000000..0b2b3968 --- /dev/null +++ b/backend/migrations/046b_add_group_supported_model_scopes.sql @@ -0,0 +1,6 @@ +-- 添加分组支持的模型系列字段 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS supported_model_scopes JSONB NOT NULL +DEFAULT '["claude", "gemini_text", "gemini_image"]'::jsonb; + +COMMENT ON COLUMN groups.supported_model_scopes IS '支持的模型系列:claude, gemini_text, gemini_image'; diff --git a/backend/tools.go b/backend/tools.go deleted file mode 100644 index f06d2c78..00000000 --- a/backend/tools.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build tools -// +build tools - -package tools - -import ( - _ "entgo.io/ent/cmd/ent" - _ "github.com/google/wire/cmd/wire" -) diff --git a/docs/rename_local_migrations_20260202.sql b/docs/rename_local_migrations_20260202.sql new file mode 100644 index 00000000..911ed17d --- /dev/null +++ b/docs/rename_local_migrations_20260202.sql @@ -0,0 +1,34 @@ +-- 修正 schema_migrations 中“本地改名”的迁移文件名 +-- 适用场景:你已执行过旧文件名的迁移,合并后仅改了自己这边的文件名 + +BEGIN; + +UPDATE schema_migrations +SET filename = '042b_add_ops_system_metrics_switch_count.sql' +WHERE filename = '042_add_ops_system_metrics_switch_count.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '042b_add_ops_system_metrics_switch_count.sql' + ); + +UPDATE schema_migrations +SET filename = '043b_add_group_invalid_request_fallback.sql' +WHERE filename = '043_add_group_invalid_request_fallback.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '043b_add_group_invalid_request_fallback.sql' + ); + +UPDATE schema_migrations +SET filename = '044b_add_group_mcp_xml_inject.sql' +WHERE filename = '044_add_group_mcp_xml_inject.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '044b_add_group_mcp_xml_inject.sql' + ); + +UPDATE schema_migrations +SET filename = '046b_add_group_supported_model_scopes.sql' +WHERE filename = '046_add_group_supported_model_scopes.sql' + AND NOT EXISTS ( + SELECT 1 FROM schema_migrations WHERE filename = '046b_add_group_supported_model_scopes.sql' + ); + +COMMIT; diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index bf2c246c..a1c41e8c 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -136,6 +136,7 @@ export interface OpsThroughputTrendPoint { bucket_start: string request_count: number token_consumed: number + switch_count?: number qps: number tps: number } @@ -284,6 +285,7 @@ export interface OpsSystemMetricsSnapshot { goroutine_count?: number | null concurrency_queue_depth?: number | null + account_switch_count?: number | null } export interface OpsJobHeartbeat { diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index cdae1359..c5943789 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -44,6 +44,8 @@ export async function getById(id: number): Promise { * @param customKey - Optional custom key value * @param ipWhitelist - Optional IP whitelist * @param ipBlacklist - Optional IP blacklist + * @param quota - Optional quota limit in USD (0 = unlimited) + * @param expiresInDays - Optional days until expiry (undefined = never expires) * @returns Created API key */ export async function create( @@ -51,7 +53,9 @@ export async function create( groupId?: number | null, customKey?: string, ipWhitelist?: string[], - ipBlacklist?: string[] + ipBlacklist?: string[], + quota?: number, + expiresInDays?: number ): Promise { const payload: CreateApiKeyRequest = { name } if (groupId !== undefined) { @@ -66,6 +70,12 @@ export async function create( if (ipBlacklist && ipBlacklist.length > 0) { payload.ip_blacklist = ipBlacklist } + if (quota !== undefined && quota > 0) { + payload.quota = quota + } + if (expiresInDays !== undefined && expiresInDays > 0) { + payload.expires_in_days = expiresInDays + } const { data } = await apiClient.post('/keys', payload) return data diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 8e525fa3..8dcddff7 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -56,7 +56,6 @@ > -
- +
-
-
+ + + +
+
+ + +
+
+ + +

{{ t('admin.accounts.upstream.baseUrlHint') }}

+
+
+ + +

{{ t('admin.accounts.upstream.apiKeyHint') }}

@@ -1980,6 +2046,9 @@ const interceptWarmupRequests = ref(false) const autoPauseOnExpired = ref(true) const enableSoraOnOpenAIOAuth = ref(false) // OpenAI OAuth 时同时启用 Sora const mixedScheduling = ref(false) // For antigravity accounts: enable mixed scheduling +const antigravityAccountType = ref<'oauth' | 'upstream'>('oauth') // For antigravity: oauth or upstream +const upstreamBaseUrl = ref('') // For upstream type: base URL +const upstreamApiKey = ref('') // For upstream type: API key const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('google_one') @@ -2082,7 +2151,13 @@ const form = reactive({ }) // Helper to check if current type needs OAuth flow -const isOAuthFlow = computed(() => accountCategory.value === 'oauth-based') +const isOAuthFlow = computed(() => { + // Antigravity upstream 类型不需要 OAuth 流程 + if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { + return false + } + return accountCategory.value === 'oauth-based' +}) const isManualInputMethod = computed(() => { return oauthFlowRef.value?.inputMethod === 'manual' @@ -2122,10 +2197,15 @@ watch( } ) -// Sync form.type based on accountCategory and addMethod +// Sync form.type based on accountCategory, addMethod, and antigravityAccountType watch( - [accountCategory, addMethod], - ([category, method]) => { + [accountCategory, addMethod, antigravityAccountType], + ([category, method, agType]) => { + // Antigravity upstream 类型 + if (form.platform === 'antigravity' && agType === 'upstream') { + form.type = 'upstream' + return + } if (category === 'oauth-based') { form.type = method as AccountType // 'oauth' or 'setup-token' } else { @@ -2153,9 +2233,10 @@ watch( if (newPlatform !== 'anthropic') { interceptWarmupRequests.value = false } - // Antigravity only supports OAuth + // Antigravity: reset to OAuth by default, but allow upstream selection if (newPlatform === 'antigravity') { accountCategory.value = 'oauth-based' + antigravityAccountType.value = 'oauth' } // Reset OAuth states oauth.resetState() @@ -2389,6 +2470,9 @@ const resetForm = () => { sessionIdleTimeout.value = null tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false + antigravityAccountType.value = 'oauth' + upstreamBaseUrl.value = '' + upstreamApiKey.value = '' tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -2470,6 +2554,36 @@ const handleSubmit = async () => { return } + // For Antigravity upstream type, create directly + if (form.platform === 'antigravity' && antigravityAccountType.value === 'upstream') { + if (!form.name.trim()) { + appStore.showError(t('admin.accounts.pleaseEnterAccountName')) + return + } + if (!upstreamBaseUrl.value.trim()) { + appStore.showError(t('admin.accounts.upstream.pleaseEnterBaseUrl')) + return + } + if (!upstreamApiKey.value.trim()) { + appStore.showError(t('admin.accounts.upstream.pleaseEnterApiKey')) + return + } + + submitting.value = true + try { + const credentials: Record = { + base_url: upstreamBaseUrl.value.trim(), + api_key: upstreamApiKey.value.trim() + } + await createAccountAndFinish(form.platform, 'upstream', credentials) + } catch (error: any) { + appStore.showError(error.response?.data?.detail || t('admin.accounts.failedToCreate')) + } finally { + submitting.value = false + } + return + } + // For apikey type, create directly if (!apiKeyValue.value.trim()) { appStore.showError(t('admin.accounts.pleaseEnterApiKey')) diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 5edbd3b6..fbb1942a 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -238,14 +238,14 @@ - + + diff --git a/frontend/src/views/setup/SetupWizardView.vue b/frontend/src/views/setup/SetupWizardView.vue index f3c773ca..fcf5aa72 100644 --- a/frontend/src/views/setup/SetupWizardView.vue +++ b/frontend/src/views/setup/SetupWizardView.vue @@ -91,6 +91,18 @@
+
+
+

+ {{ t("setup.redis.enableTls") }} +

+

+ {{ t("setup.redis.enableTlsHint") }} +

+
+ +
+
diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index b72ae9ad..51b015fa 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -108,12 +108,53 @@ ${{ (usageStats[row.id]?.total_actual_cost ?? 0).toFixed(4) }}
+ +
+
+ {{ t('keys.quota') }}: + + ${{ row.quota_used?.toFixed(2) || '0.00' }} / ${{ row.quota?.toFixed(2) }} + +
+
+
+
+
+ + @@ -334,6 +375,145 @@
+ + +
+ + + +
+
+
+ $ + +
+

{{ t('keys.quotaAmountHint') }}

+
+ + +
+ +
+
+ + ${{ selectedKey.quota_used?.toFixed(4) || '0.0000' }} + + / + + ${{ selectedKey.quota?.toFixed(2) || '0.00' }} + +
+ +
+
+
+
+ + +
+
+ + +
+ +
+ +
+ + +
+ + +
+ + +

{{ t('keys.expirationDateHint') }}

+
+ + +
+ {{ t('keys.currentExpiration') }}: + + {{ formatDateTime(selectedKey.expires_at) }} + +
+
+