diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index f632bff3..48f15b5c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -110,7 +110,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) - groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) openAIOAuthClient := repository.NewOpenAIOAuthClient() @@ -143,6 +142,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) rpmCache := repository.NewRPMCache(redisClient) + groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) + groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index ff1c1b88..acdd0d18 100644 --- a/backend/ent/migrate/schema.go +++ 
b/backend/ent/migrate/schema.go @@ -716,6 +716,7 @@ var ( {Name: "id", Type: field.TypeInt64, Increment: true}, {Name: "request_id", Type: field.TypeString, Size: 64}, {Name: "model", Type: field.TypeString, Size: 100}, + {Name: "upstream_model", Type: field.TypeString, Nullable: true, Size: 100}, {Name: "input_tokens", Type: field.TypeInt, Default: 0}, {Name: "output_tokens", Type: field.TypeInt, Default: 0}, {Name: "cache_creation_tokens", Type: field.TypeInt, Default: 0}, @@ -755,31 +756,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -788,32 +789,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: 
[]*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[33]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_model", @@ -828,17 +829,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[32], UsageLogsColumns[28]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[28]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[28]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 652adcac..ff58fa9e 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -18239,6 +18239,7 @@ type UsageLogMutation struct { id *int64 request_id *string model *string + upstream_model *string input_tokens *int addinput_tokens *int output_tokens *int @@ -18576,6 +18577,55 @@ func (m *UsageLogMutation) ResetModel() { m.model = nil } +// SetUpstreamModel sets the "upstream_model" field. +func (m *UsageLogMutation) SetUpstreamModel(s string) { + m.upstream_model = &s +} + +// UpstreamModel returns the value of the "upstream_model" field in the mutation. 
+func (m *UsageLogMutation) UpstreamModel() (r string, exists bool) { + v := m.upstream_model + if v == nil { + return + } + return *v, true +} + +// OldUpstreamModel returns the old "upstream_model" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldUpstreamModel(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpstreamModel is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpstreamModel requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpstreamModel: %w", err) + } + return oldValue.UpstreamModel, nil +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (m *UsageLogMutation) ClearUpstreamModel() { + m.upstream_model = nil + m.clearedFields[usagelog.FieldUpstreamModel] = struct{}{} +} + +// UpstreamModelCleared returns if the "upstream_model" field was cleared in this mutation. +func (m *UsageLogMutation) UpstreamModelCleared() bool { + _, ok := m.clearedFields[usagelog.FieldUpstreamModel] + return ok +} + +// ResetUpstreamModel resets all changes to the "upstream_model" field. +func (m *UsageLogMutation) ResetUpstreamModel() { + m.upstream_model = nil + delete(m.clearedFields, usagelog.FieldUpstreamModel) +} + // SetGroupID sets the "group_id" field. func (m *UsageLogMutation) SetGroupID(i int64) { m.group = &i @@ -20197,7 +20247,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 32) + fields := make([]string, 0, 33) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -20213,6 +20263,9 @@ func (m *UsageLogMutation) Fields() []string { if m.model != nil { fields = append(fields, usagelog.FieldModel) } + if m.upstream_model != nil { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.group != nil { fields = append(fields, usagelog.FieldGroupID) } @@ -20312,6 +20365,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.RequestID() case usagelog.FieldModel: return m.Model() + case usagelog.FieldUpstreamModel: + return m.UpstreamModel() case usagelog.FieldGroupID: return m.GroupID() case usagelog.FieldSubscriptionID: @@ -20385,6 +20440,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldRequestID(ctx) case usagelog.FieldModel: return m.OldModel(ctx) + case usagelog.FieldUpstreamModel: + return m.OldUpstreamModel(ctx) case usagelog.FieldGroupID: return m.OldGroupID(ctx) case usagelog.FieldSubscriptionID: @@ -20483,6 +20540,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetModel(v) return nil + case usagelog.FieldUpstreamModel: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpstreamModel(v) + return nil case usagelog.FieldGroupID: v, ok := value.(int64) if !ok { @@ -20921,6 +20985,9 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error { // mutation. 
func (m *UsageLogMutation) ClearedFields() []string { var fields []string + if m.FieldCleared(usagelog.FieldUpstreamModel) { + fields = append(fields, usagelog.FieldUpstreamModel) + } if m.FieldCleared(usagelog.FieldGroupID) { fields = append(fields, usagelog.FieldGroupID) } @@ -20962,6 +21029,9 @@ func (m *UsageLogMutation) FieldCleared(name string) bool { // error if the field is not defined in the schema. func (m *UsageLogMutation) ClearField(name string) error { switch name { + case usagelog.FieldUpstreamModel: + m.ClearUpstreamModel() + return nil case usagelog.FieldGroupID: m.ClearGroupID() return nil @@ -21012,6 +21082,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldModel: m.ResetModel() return nil + case usagelog.FieldUpstreamModel: + m.ResetUpstreamModel() + return nil case usagelog.FieldGroupID: m.ResetGroupID() return nil diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index b8facf36..2401e553 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -821,92 +821,96 @@ func init() { return nil } }() + // usagelogDescUpstreamModel is the schema descriptor for upstream_model field. + usagelogDescUpstreamModel := usagelogFields[5].Descriptor() + // usagelog.UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + usagelog.UpstreamModelValidator = usagelogDescUpstreamModel.Validators[0].(func(string) error) // usagelogDescInputTokens is the schema descriptor for input_tokens field. - usagelogDescInputTokens := usagelogFields[7].Descriptor() + usagelogDescInputTokens := usagelogFields[8].Descriptor() // usagelog.DefaultInputTokens holds the default value on creation for the input_tokens field. usagelog.DefaultInputTokens = usagelogDescInputTokens.Default.(int) // usagelogDescOutputTokens is the schema descriptor for output_tokens field. 
- usagelogDescOutputTokens := usagelogFields[8].Descriptor() + usagelogDescOutputTokens := usagelogFields[9].Descriptor() // usagelog.DefaultOutputTokens holds the default value on creation for the output_tokens field. usagelog.DefaultOutputTokens = usagelogDescOutputTokens.Default.(int) // usagelogDescCacheCreationTokens is the schema descriptor for cache_creation_tokens field. - usagelogDescCacheCreationTokens := usagelogFields[9].Descriptor() + usagelogDescCacheCreationTokens := usagelogFields[10].Descriptor() // usagelog.DefaultCacheCreationTokens holds the default value on creation for the cache_creation_tokens field. usagelog.DefaultCacheCreationTokens = usagelogDescCacheCreationTokens.Default.(int) // usagelogDescCacheReadTokens is the schema descriptor for cache_read_tokens field. - usagelogDescCacheReadTokens := usagelogFields[10].Descriptor() + usagelogDescCacheReadTokens := usagelogFields[11].Descriptor() // usagelog.DefaultCacheReadTokens holds the default value on creation for the cache_read_tokens field. usagelog.DefaultCacheReadTokens = usagelogDescCacheReadTokens.Default.(int) // usagelogDescCacheCreation5mTokens is the schema descriptor for cache_creation_5m_tokens field. - usagelogDescCacheCreation5mTokens := usagelogFields[11].Descriptor() + usagelogDescCacheCreation5mTokens := usagelogFields[12].Descriptor() // usagelog.DefaultCacheCreation5mTokens holds the default value on creation for the cache_creation_5m_tokens field. usagelog.DefaultCacheCreation5mTokens = usagelogDescCacheCreation5mTokens.Default.(int) // usagelogDescCacheCreation1hTokens is the schema descriptor for cache_creation_1h_tokens field. - usagelogDescCacheCreation1hTokens := usagelogFields[12].Descriptor() + usagelogDescCacheCreation1hTokens := usagelogFields[13].Descriptor() // usagelog.DefaultCacheCreation1hTokens holds the default value on creation for the cache_creation_1h_tokens field. 
usagelog.DefaultCacheCreation1hTokens = usagelogDescCacheCreation1hTokens.Default.(int) // usagelogDescInputCost is the schema descriptor for input_cost field. - usagelogDescInputCost := usagelogFields[13].Descriptor() + usagelogDescInputCost := usagelogFields[14].Descriptor() // usagelog.DefaultInputCost holds the default value on creation for the input_cost field. usagelog.DefaultInputCost = usagelogDescInputCost.Default.(float64) // usagelogDescOutputCost is the schema descriptor for output_cost field. - usagelogDescOutputCost := usagelogFields[14].Descriptor() + usagelogDescOutputCost := usagelogFields[15].Descriptor() // usagelog.DefaultOutputCost holds the default value on creation for the output_cost field. usagelog.DefaultOutputCost = usagelogDescOutputCost.Default.(float64) // usagelogDescCacheCreationCost is the schema descriptor for cache_creation_cost field. - usagelogDescCacheCreationCost := usagelogFields[15].Descriptor() + usagelogDescCacheCreationCost := usagelogFields[16].Descriptor() // usagelog.DefaultCacheCreationCost holds the default value on creation for the cache_creation_cost field. usagelog.DefaultCacheCreationCost = usagelogDescCacheCreationCost.Default.(float64) // usagelogDescCacheReadCost is the schema descriptor for cache_read_cost field. - usagelogDescCacheReadCost := usagelogFields[16].Descriptor() + usagelogDescCacheReadCost := usagelogFields[17].Descriptor() // usagelog.DefaultCacheReadCost holds the default value on creation for the cache_read_cost field. usagelog.DefaultCacheReadCost = usagelogDescCacheReadCost.Default.(float64) // usagelogDescTotalCost is the schema descriptor for total_cost field. - usagelogDescTotalCost := usagelogFields[17].Descriptor() + usagelogDescTotalCost := usagelogFields[18].Descriptor() // usagelog.DefaultTotalCost holds the default value on creation for the total_cost field. 
usagelog.DefaultTotalCost = usagelogDescTotalCost.Default.(float64) // usagelogDescActualCost is the schema descriptor for actual_cost field. - usagelogDescActualCost := usagelogFields[18].Descriptor() + usagelogDescActualCost := usagelogFields[19].Descriptor() // usagelog.DefaultActualCost holds the default value on creation for the actual_cost field. usagelog.DefaultActualCost = usagelogDescActualCost.Default.(float64) // usagelogDescRateMultiplier is the schema descriptor for rate_multiplier field. - usagelogDescRateMultiplier := usagelogFields[19].Descriptor() + usagelogDescRateMultiplier := usagelogFields[20].Descriptor() // usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field. usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64) // usagelogDescBillingType is the schema descriptor for billing_type field. - usagelogDescBillingType := usagelogFields[21].Descriptor() + usagelogDescBillingType := usagelogFields[22].Descriptor() // usagelog.DefaultBillingType holds the default value on creation for the billing_type field. usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8) // usagelogDescStream is the schema descriptor for stream field. - usagelogDescStream := usagelogFields[22].Descriptor() + usagelogDescStream := usagelogFields[23].Descriptor() // usagelog.DefaultStream holds the default value on creation for the stream field. usagelog.DefaultStream = usagelogDescStream.Default.(bool) // usagelogDescUserAgent is the schema descriptor for user_agent field. - usagelogDescUserAgent := usagelogFields[25].Descriptor() + usagelogDescUserAgent := usagelogFields[26].Descriptor() // usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save. usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error) // usagelogDescIPAddress is the schema descriptor for ip_address field. 
- usagelogDescIPAddress := usagelogFields[26].Descriptor() + usagelogDescIPAddress := usagelogFields[27].Descriptor() // usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save. usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error) // usagelogDescImageCount is the schema descriptor for image_count field. - usagelogDescImageCount := usagelogFields[27].Descriptor() + usagelogDescImageCount := usagelogFields[28].Descriptor() // usagelog.DefaultImageCount holds the default value on creation for the image_count field. usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int) // usagelogDescImageSize is the schema descriptor for image_size field. - usagelogDescImageSize := usagelogFields[28].Descriptor() + usagelogDescImageSize := usagelogFields[29].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) // usagelogDescMediaType is the schema descriptor for media_type field. - usagelogDescMediaType := usagelogFields[29].Descriptor() + usagelogDescMediaType := usagelogFields[30].Descriptor() // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[31].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. 
- usagelogDescCreatedAt := usagelogFields[31].Descriptor() + usagelogDescCreatedAt := usagelogFields[32].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index dcca1a0a..8f8a5255 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -41,6 +41,12 @@ func (UsageLog) Fields() []ent.Field { field.String("model"). MaxLen(100). NotEmpty(), + // UpstreamModel stores the actual upstream model name when model mapping + // is applied. NULL means no mapping — the requested model was used as-is. + field.String("upstream_model"). + MaxLen(100). + Optional(). + Nillable(), field.Int64("group_id"). Optional(). Nillable(), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index f6968d0d..014851c9 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -32,6 +32,8 @@ type UsageLog struct { RequestID string `json:"request_id,omitempty"` // Model holds the value of the "model" field. Model string `json:"model,omitempty"` + // UpstreamModel holds the value of the "upstream_model" field. + UpstreamModel *string `json:"upstream_model,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // SubscriptionID holds the value of the "subscription_id" field. 
@@ -175,7 +177,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUpstreamModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -230,6 +232,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Model = value.String } + case usagelog.FieldUpstreamModel: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field upstream_model", values[i]) + } else if value.Valid { + _m.UpstreamModel = new(string) + *_m.UpstreamModel = value.String + } case usagelog.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -477,6 +486,11 @@ func (_m *UsageLog) String() string { builder.WriteString("model=") builder.WriteString(_m.Model) builder.WriteString(", ") + if v := _m.UpstreamModel; v != nil { + builder.WriteString("upstream_model=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git 
a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index ba97b843..789407e7 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -24,6 +24,8 @@ const ( FieldRequestID = "request_id" // FieldModel holds the string denoting the model field in the database. FieldModel = "model" + // FieldUpstreamModel holds the string denoting the upstream_model field in the database. + FieldUpstreamModel = "upstream_model" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldSubscriptionID holds the string denoting the subscription_id field in the database. @@ -135,6 +137,7 @@ var Columns = []string{ FieldAccountID, FieldRequestID, FieldModel, + FieldUpstreamModel, FieldGroupID, FieldSubscriptionID, FieldInputTokens, @@ -179,6 +182,8 @@ var ( RequestIDValidator func(string) error // ModelValidator is a validator for the "model" field. It is called by the builders before save. ModelValidator func(string) error + // UpstreamModelValidator is a validator for the "upstream_model" field. It is called by the builders before save. + UpstreamModelValidator func(string) error // DefaultInputTokens holds the default value on creation for the "input_tokens" field. DefaultInputTokens int // DefaultOutputTokens holds the default value on creation for the "output_tokens" field. @@ -258,6 +263,11 @@ func ByModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldModel, opts...).ToFunc() } +// ByUpstreamModel orders the results by the upstream_model field. +func ByUpstreamModel(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpstreamModel, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. 
func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index af960335..5f341976 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -80,6 +80,11 @@ func Model(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldModel, v)) } +// UpstreamModel applies equality check predicate on the "upstream_model" field. It's identical to UpstreamModelEQ. +func UpstreamModel(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. func GroupID(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) @@ -405,6 +410,81 @@ func ModelContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldModel, v)) } +// UpstreamModelEQ applies the EQ predicate on the "upstream_model" field. +func UpstreamModelEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelNEQ applies the NEQ predicate on the "upstream_model" field. +func UpstreamModelNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldUpstreamModel, v)) +} + +// UpstreamModelIn applies the In predicate on the "upstream_model" field. +func UpstreamModelIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelNotIn applies the NotIn predicate on the "upstream_model" field. +func UpstreamModelNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldUpstreamModel, vs...)) +} + +// UpstreamModelGT applies the GT predicate on the "upstream_model" field. 
+func UpstreamModelGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldUpstreamModel, v)) +} + +// UpstreamModelGTE applies the GTE predicate on the "upstream_model" field. +func UpstreamModelGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldUpstreamModel, v)) +} + +// UpstreamModelLT applies the LT predicate on the "upstream_model" field. +func UpstreamModelLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldUpstreamModel, v)) +} + +// UpstreamModelLTE applies the LTE predicate on the "upstream_model" field. +func UpstreamModelLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldUpstreamModel, v)) +} + +// UpstreamModelContains applies the Contains predicate on the "upstream_model" field. +func UpstreamModelContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldUpstreamModel, v)) +} + +// UpstreamModelHasPrefix applies the HasPrefix predicate on the "upstream_model" field. +func UpstreamModelHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldUpstreamModel, v)) +} + +// UpstreamModelHasSuffix applies the HasSuffix predicate on the "upstream_model" field. +func UpstreamModelHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldUpstreamModel, v)) +} + +// UpstreamModelIsNil applies the IsNil predicate on the "upstream_model" field. +func UpstreamModelIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldUpstreamModel)) +} + +// UpstreamModelNotNil applies the NotNil predicate on the "upstream_model" field. +func UpstreamModelNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldUpstreamModel)) +} + +// UpstreamModelEqualFold applies the EqualFold predicate on the "upstream_model" field. 
+func UpstreamModelEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldUpstreamModel, v)) +} + +// UpstreamModelContainsFold applies the ContainsFold predicate on the "upstream_model" field. +func UpstreamModelContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldUpstreamModel, v)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldGroupID, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index e0285a5e..26be5dcb 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -57,6 +57,20 @@ func (_c *UsageLogCreate) SetModel(v string) *UsageLogCreate { return _c } +// SetUpstreamModel sets the "upstream_model" field. +func (_c *UsageLogCreate) SetUpstreamModel(v string) *UsageLogCreate { + _c.mutation.SetUpstreamModel(v) + return _c +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableUpstreamModel(v *string) *UsageLogCreate { + if v != nil { + _c.SetUpstreamModel(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. 
func (_c *UsageLogCreate) SetGroupID(v int64) *UsageLogCreate { _c.mutation.SetGroupID(v) @@ -596,6 +610,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _c.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if _, ok := _c.mutation.InputTokens(); !ok { return &ValidationError{Name: "input_tokens", err: errors.New(`ent: missing required field "UsageLog.input_tokens"`)} } @@ -714,6 +733,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldModel, field.TypeString, value) _node.Model = value } + if value, ok := _c.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + _node.UpstreamModel = &value + } if value, ok := _c.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) _node.InputTokens = value @@ -1011,6 +1034,24 @@ func (u *UsageLogUpsert) UpdateModel() *UsageLogUpsert { return u } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsert) SetUpstreamModel(v string) *UsageLogUpsert { + u.Set(usagelog.FieldUpstreamModel, v) + return u +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateUpstreamModel() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldUpstreamModel) + return u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsert) ClearUpstreamModel() *UsageLogUpsert { + u.SetNull(usagelog.FieldUpstreamModel) + return u +} + // SetGroupID sets the "group_id" field. 
func (u *UsageLogUpsert) SetGroupID(v int64) *UsageLogUpsert { u.Set(usagelog.FieldGroupID, v) @@ -1600,6 +1641,27 @@ func (u *UsageLogUpsertOne) UpdateModel() *UsageLogUpsertOne { }) } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertOne) SetUpstreamModel(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertOne) ClearUpstreamModel() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. func (u *UsageLogUpsertOne) SetGroupID(v int64) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2434,6 +2496,27 @@ func (u *UsageLogUpsertBulk) UpdateModel() *UsageLogUpsertBulk { }) } +// SetUpstreamModel sets the "upstream_model" field. +func (u *UsageLogUpsertBulk) SetUpstreamModel(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetUpstreamModel(v) + }) +} + +// UpdateUpstreamModel sets the "upstream_model" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateUpstreamModel() + }) +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (u *UsageLogUpsertBulk) ClearUpstreamModel() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearUpstreamModel() + }) +} + // SetGroupID sets the "group_id" field. 
func (u *UsageLogUpsertBulk) SetGroupID(v int64) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index b46e5b56..b7c4632c 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -102,6 +102,26 @@ func (_u *UsageLogUpdate) SetNillableModel(v *string) *UsageLogUpdate { return _u } +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdate) SetUpstreamModel(v string) *UsageLogUpdate { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableUpstreamModel(v *string) *UsageLogUpdate { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. +func (_u *UsageLogUpdate) ClearUpstreamModel() *UsageLogUpdate { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. 
func (_u *UsageLogUpdate) SetGroupID(v int64) *UsageLogUpdate { _u.mutation.SetGroupID(v) @@ -745,6 +765,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -795,6 +820,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } @@ -1177,6 +1208,26 @@ func (_u *UsageLogUpdateOne) SetNillableModel(v *string) *UsageLogUpdateOne { return _u } +// SetUpstreamModel sets the "upstream_model" field. +func (_u *UsageLogUpdateOne) SetUpstreamModel(v string) *UsageLogUpdateOne { + _u.mutation.SetUpstreamModel(v) + return _u +} + +// SetNillableUpstreamModel sets the "upstream_model" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableUpstreamModel(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetUpstreamModel(*v) + } + return _u +} + +// ClearUpstreamModel clears the value of the "upstream_model" field. 
+func (_u *UsageLogUpdateOne) ClearUpstreamModel() *UsageLogUpdateOne { + _u.mutation.ClearUpstreamModel() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *UsageLogUpdateOne) SetGroupID(v int64) *UsageLogUpdateOne { _u.mutation.SetGroupID(v) @@ -1833,6 +1884,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.model": %w`, err)} } } + if v, ok := _u.mutation.UpstreamModel(); ok { + if err := usagelog.UpstreamModelValidator(v); err != nil { + return &ValidationError{Name: "upstream_model", err: fmt.Errorf(`ent: validator failed for field "UsageLog.upstream_model": %w`, err)} + } + } if v, ok := _u.mutation.UserAgent(); ok { if err := usagelog.UserAgentValidator(v); err != nil { return &ValidationError{Name: "user_agent", err: fmt.Errorf(`ent: validator failed for field "UsageLog.user_agent": %w`, err)} @@ -1900,6 +1956,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if value, ok := _u.mutation.Model(); ok { _spec.SetField(usagelog.FieldModel, field.TypeString, value) } + if value, ok := _u.mutation.UpstreamModel(); ok { + _spec.SetField(usagelog.FieldUpstreamModel, field.TypeString, value) + } + if _u.mutation.UpstreamModelCleared() { + _spec.ClearField(usagelog.FieldUpstreamModel, field.TypeString) + } if value, ok := _u.mutation.InputTokens(); ok { _spec.SetField(usagelog.FieldInputTokens, field.TypeInt, value) } diff --git a/backend/go.sum b/backend/go.sum index 324fe652..270be5f8 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -22,8 +22,6 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= 
-github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= -github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= @@ -60,8 +58,6 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWA github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= -github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= -github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= @@ -98,10 +94,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= -github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= -github.com/clipperhouse/uax29/v2 v2.5.0 
h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= -github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -238,8 +230,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= -github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -273,8 +263,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 
h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -326,8 +314,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index c51046a2..4e69ca02 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -82,8 +82,8 @@ var DefaultAntigravityModelMapping = map[string]string{ "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", // 迁移旧模型 "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", // Claude Haiku → Sonnet(无 Haiku 支持) - "claude-haiku-4-5": "claude-sonnet-4-5", - "claude-haiku-4-5-20251001": "claude-sonnet-4-5", + "claude-haiku-4-5": "claude-sonnet-4-6", + "claude-haiku-4-5-20251001": "claude-sonnet-4-6", // Gemini 2.5 白名单 "gemini-2.5-flash": "gemini-2.5-flash", "gemini-2.5-flash-image": "gemini-2.5-flash-image", diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 4de10d3e..cba3ae21 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -17,7 +17,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { adminSvc := newStubAdminService() userHandler := NewUserHandler(adminSvc, nil) - 
groupHandler := NewGroupHandler(adminSvc) + groupHandler := NewGroupHandler(adminSvc, nil, nil) proxyHandler := NewProxyHandler(adminSvc) redeemHandler := NewRedeemHandler(adminSvc, nil) diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index a34bbd39..2a214471 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -273,6 +273,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + modelSource := usagestats.ModelSourceRequested var requestType *int16 var stream *bool var billingType *int8 @@ -297,6 +298,13 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } + if rawModelSource := strings.TrimSpace(c.Query("model_source")); rawModelSource != "" { + if !usagestats.IsValidModelSource(rawModelSource) { + response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping") + return + } + modelSource = rawModelSource + } if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { parsed, err := service.ParseUsageRequestType(requestTypeStr) if err != nil { @@ -323,7 +331,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + stats, hit, err := h.getModelStatsCached(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, modelSource, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return @@ -619,6 +627,12 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { } } dim.Model = c.Query("model") + rawModelSource := strings.TrimSpace(c.DefaultQuery("model_source", usagestats.ModelSourceRequested)) + if 
!usagestats.IsValidModelSource(rawModelSource) { + response.BadRequest(c, "Invalid model_source, use requested/upstream/mapping") + return + } + dim.ModelType = rawModelSource dim.Endpoint = c.Query("endpoint") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go index 9aec61d4..6056f725 100644 --- a/backend/internal/handler/admin/dashboard_handler_request_type_test.go +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -149,6 +149,28 @@ func TestDashboardModelStatsInvalidStream(t *testing.T) { require.Equal(t, http.StatusBadRequest, rec.Code) } +func TestDashboardModelStatsInvalidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsValidModelSource(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?model_source=upstream", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} + func TestDashboardUsersRankingLimitAndCache(t *testing.T) { dashboardUsersRankingCache = newSnapshotCache(5 * time.Minute) repo := &dashboardUsageRepoCapture{ diff --git a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go index 2c1dbd59..b3a05111 100644 --- a/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go +++ b/backend/internal/handler/admin/dashboard_handler_user_breakdown_test.go @@ -73,9 
+73,35 @@ func TestGetUserBreakdown_ModelFilter(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) require.Equal(t, "claude-opus-4-6", repo.capturedDim.Model) + require.Equal(t, usagestats.ModelSourceRequested, repo.capturedDim.ModelType) require.Equal(t, int64(0), repo.capturedDim.GroupID) } +func TestGetUserBreakdown_ModelSourceFilter(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model=claude-opus-4-6&model_source=upstream", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, usagestats.ModelSourceUpstream, repo.capturedDim.ModelType) +} + +func TestGetUserBreakdown_InvalidModelSource(t *testing.T) { + repo := &userBreakdownRepoCapture{} + router := newUserBreakdownRouter(repo) + + req := httptest.NewRequest(http.MethodGet, + "/admin/dashboard/user-breakdown?start_date=2026-03-01&end_date=2026-03-16&model_source=foobar", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code) +} + func TestGetUserBreakdown_EndpointFilter(t *testing.T) { repo := &userBreakdownRepoCapture{} router := newUserBreakdownRouter(repo) diff --git a/backend/internal/handler/admin/dashboard_query_cache.go b/backend/internal/handler/admin/dashboard_query_cache.go index 47af5117..815c5161 100644 --- a/backend/internal/handler/admin/dashboard_query_cache.go +++ b/backend/internal/handler/admin/dashboard_query_cache.go @@ -38,6 +38,7 @@ type dashboardModelGroupCacheKey struct { APIKeyID int64 `json:"api_key_id"` AccountID int64 `json:"account_id"` GroupID int64 `json:"group_id"` + ModelSource string `json:"model_source,omitempty"` RequestType *int16 `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` @@ -111,6 +112,7 @@ func (h *DashboardHandler) 
getModelStatsCached( ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, + modelSource string, requestType *int16, stream *bool, billingType *int8, @@ -122,12 +124,13 @@ func (h *DashboardHandler) getModelStatsCached( APIKeyID: apiKeyID, AccountID: accountID, GroupID: groupID, + ModelSource: usagestats.NormalizeModelSource(modelSource), RequestType: requestType, Stream: stream, BillingType: billingType, }) entry, hit, err := dashboardModelStatsCache.GetOrLoad(key, func() (any, error) { - return h.dashboardService.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + return h.dashboardService.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, modelSource) }) if err != nil { return nil, hit, err diff --git a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go index 16e10339..517ae7bd 100644 --- a/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go +++ b/backend/internal/handler/admin/dashboard_snapshot_v2_handler.go @@ -200,6 +200,7 @@ func (h *DashboardHandler) buildSnapshotV2Response( filters.APIKeyID, filters.AccountID, filters.GroupID, + usagestats.ModelSourceRequested, filters.RequestType, filters.Stream, filters.BillingType, diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 4ffe64ee..459fd949 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -9,6 +9,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -16,7 +17,9 @@ import ( // GroupHandler handles admin group 
management type GroupHandler struct { - adminService service.AdminService + adminService service.AdminService + dashboardService *service.DashboardService + groupCapacityService *service.GroupCapacityService } type optionalLimitField struct { @@ -69,9 +72,11 @@ func (f optionalLimitField) ToServiceInput() *float64 { } // NewGroupHandler creates a new admin group handler -func NewGroupHandler(adminService service.AdminService) *GroupHandler { +func NewGroupHandler(adminService service.AdminService, dashboardService *service.DashboardService, groupCapacityService *service.GroupCapacityService) *GroupHandler { return &GroupHandler{ - adminService: adminService, + adminService: adminService, + dashboardService: dashboardService, + groupCapacityService: groupCapacityService, } } @@ -363,6 +368,33 @@ func (h *GroupHandler) GetStats(c *gin.Context) { _ = groupID // TODO: implement actual stats } +// GetUsageSummary returns today's and cumulative cost for all groups. +// GET /api/v1/admin/groups/usage-summary?timezone=Asia/Shanghai +func (h *GroupHandler) GetUsageSummary(c *gin.Context) { + userTZ := c.Query("timezone") + now := timezone.NowInUserLocation(userTZ) + todayStart := timezone.StartOfDayInUserLocation(now, userTZ) + + results, err := h.dashboardService.GetGroupUsageSummary(c.Request.Context(), todayStart) + if err != nil { + response.Error(c, 500, "Failed to get group usage summary") + return + } + + response.Success(c, results) +} + +// GetCapacitySummary returns aggregated capacity (concurrency/sessions/RPM) for all active groups. 
+// GET /api/v1/admin/groups/capacity-summary +func (h *GroupHandler) GetCapacitySummary(c *gin.Context) { + results, err := h.groupCapacityService.GetAllGroupCapacity(c.Request.Context()) + if err != nil { + response.Error(c, 500, "Failed to get group capacity summary") + return + } + response.Success(c, results) +} + // GetGroupAPIKeys handles getting API keys in a group // GET /api/v1/admin/groups/:id/api-keys func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index c966cb7d..25456bb3 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -977,6 +977,58 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) { response.Success(c, gin.H{"message": "Admin API key deleted"}) } +// GetOverloadCooldownSettings 获取529过载冷却配置 +// GET /api/v1/admin/settings/overload-cooldown +func (h *SettingHandler) GetOverloadCooldownSettings(c *gin.Context) { + settings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.OverloadCooldownSettings{ + Enabled: settings.Enabled, + CooldownMinutes: settings.CooldownMinutes, + }) +} + +// UpdateOverloadCooldownSettingsRequest 更新529过载冷却配置请求 +type UpdateOverloadCooldownSettingsRequest struct { + Enabled bool `json:"enabled"` + CooldownMinutes int `json:"cooldown_minutes"` +} + +// UpdateOverloadCooldownSettings 更新529过载冷却配置 +// PUT /api/v1/admin/settings/overload-cooldown +func (h *SettingHandler) UpdateOverloadCooldownSettings(c *gin.Context) { + var req UpdateOverloadCooldownSettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + settings := &service.OverloadCooldownSettings{ + Enabled: req.Enabled, + CooldownMinutes: req.CooldownMinutes, + } + + if err := 
h.settingService.SetOverloadCooldownSettings(c.Request.Context(), settings); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updatedSettings, err := h.settingService.GetOverloadCooldownSettings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.OverloadCooldownSettings{ + Enabled: updatedSettings.Enabled, + CooldownMinutes: updatedSettings.CooldownMinutes, + }) +} + // GetStreamTimeoutSettings 获取流超时处理配置 // GET /api/v1/admin/settings/stream-timeout func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 342964b6..611666de 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -77,12 +77,13 @@ func (h *SubscriptionHandler) List(c *gin.Context) { } } status := c.Query("status") + platform := c.Query("platform") // Parse sorting parameters sortBy := c.DefaultQuery("sort_by", "created_at") sortOrder := c.DefaultQuery("sort_order", "desc") - subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder) + subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, platform, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 8e5f23e7..d1d867ee 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -135,14 +135,16 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { return nil } out := &AdminGroup{ - Group: groupFromServiceBase(g), - ModelRouting: g.ModelRouting, - ModelRoutingEnabled: g.ModelRoutingEnabled, - MCPXMLInject: g.MCPXMLInject, - DefaultMappedModel: 
g.DefaultMappedModel, - SupportedModelScopes: g.SupportedModelScopes, - AccountCount: g.AccountCount, - SortOrder: g.SortOrder, + Group: groupFromServiceBase(g), + ModelRouting: g.ModelRouting, + ModelRoutingEnabled: g.ModelRoutingEnabled, + MCPXMLInject: g.MCPXMLInject, + DefaultMappedModel: g.DefaultMappedModel, + SupportedModelScopes: g.SupportedModelScopes, + AccountCount: g.AccountCount, + ActiveAccountCount: g.ActiveAccountCount, + RateLimitedAccountCount: g.RateLimitedAccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) @@ -521,6 +523,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { AccountID: l.AccountID, RequestID: l.RequestID, Model: l.Model, + UpstreamModel: l.UpstreamModel, ServiceTier: l.ServiceTier, ReasoningEffort: l.ReasoningEffort, InboundEndpoint: l.InboundEndpoint, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 29b00bb8..b953e336 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -157,6 +157,12 @@ type ListSoraS3ProfilesResponse struct { Items []SoraS3Profile `json:"items"` } +// OverloadCooldownSettings 529过载冷却配置 DTO +type OverloadCooldownSettings struct { + Enabled bool `json:"enabled"` + CooldownMinutes int `json:"cooldown_minutes"` +} + // StreamTimeoutSettings 流超时处理配置 DTO type StreamTimeoutSettings struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index c52e357e..7b3443be 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -122,9 +122,11 @@ type AdminGroup struct { DefaultMappedModel string `json:"default_mapped_model"` // 支持的模型系列(仅 antigravity 平台使用) - SupportedModelScopes []string `json:"supported_model_scopes"` - AccountGroups []AccountGroup `json:"account_groups,omitempty"` - AccountCount int64 
`json:"account_count,omitempty"` + SupportedModelScopes []string `json:"supported_model_scopes"` + AccountGroups []AccountGroup `json:"account_groups,omitempty"` + AccountCount int64 `json:"account_count,omitempty"` + ActiveAccountCount int64 `json:"active_account_count,omitempty"` + RateLimitedAccountCount int64 `json:"rate_limited_account_count,omitempty"` // 分组排序 SortOrder int `json:"sort_order"` @@ -332,6 +334,9 @@ type UsageLog struct { AccountID int64 `json:"account_id"` RequestID string `json:"request_id"` Model string `json:"model"` + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Omitted when no mapping was applied (requested model was used as-is). + UpstreamModel *string `json:"upstream_model,omitempty"` // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string `json:"service_tier,omitempty"` // ReasoningEffort is the request's reasoning effort level. diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 6bcc0003..b9dbe0ce 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -76,7 +76,7 @@ func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service return nil, nil } func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } -func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, error) { return 0, nil } +func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil } func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { return 0, nil } diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 9e904107..4a677199 100644 
--- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -136,7 +136,7 @@ func validClaudeCodeBodyJSON() []byte { return []byte(`{ "model":"claude-3-5-sonnet-20241022", "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], - "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"} }`) } @@ -190,7 +190,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing System: []any{ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, }, - MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", } // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 @@ -209,7 +209,7 @@ func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing "system": []any{ map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, }, - "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"}, }) SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 06b09437..5c631132 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -273,8 +273,8 @@ func (r 
*stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform strin func (r *stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (r *stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (r *stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil @@ -348,6 +348,9 @@ func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTi func (s *stubUsageLogRepo) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) { return nil, nil } +func (s *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + return nil, nil +} func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { return nil, nil } diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index f42a746f..44cddb6a 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -3,6 +3,28 @@ package usagestats import "time" +const ( + ModelSourceRequested = "requested" + ModelSourceUpstream = "upstream" + ModelSourceMapping = "mapping" +) + +func IsValidModelSource(source string) bool { + switch source { + case ModelSourceRequested, ModelSourceUpstream, ModelSourceMapping: + return true + default: + return false + } +} + +func NormalizeModelSource(source string) string { + if IsValidModelSource(source) { + return source + } + return ModelSourceRequested +} + // DashboardStats 仪表盘统计 type DashboardStats 
struct { // 用户统计 @@ -90,6 +112,13 @@ type EndpointStat struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } +// GroupUsageSummary represents today's and cumulative cost for a single group. +type GroupUsageSummary struct { + GroupID int64 `json:"group_id"` + TodayCost float64 `json:"today_cost"` + TotalCost float64 `json:"total_cost"` +} + // GroupStat represents usage statistics for a single group type GroupStat struct { GroupID int64 `json:"group_id"` @@ -143,6 +172,7 @@ type UserBreakdownItem struct { type UserBreakdownDimension struct { GroupID int64 // filter by group_id (>0 to enable) Model string // filter by model name (non-empty to enable) + ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" } diff --git a/backend/internal/pkg/usagestats/usage_log_types_test.go b/backend/internal/pkg/usagestats/usage_log_types_test.go new file mode 100644 index 00000000..95cf6069 --- /dev/null +++ b/backend/internal/pkg/usagestats/usage_log_types_test.go @@ -0,0 +1,47 @@ +package usagestats + +import "testing" + +func TestIsValidModelSource(t *testing.T) { + tests := []struct { + name string + source string + want bool + }{ + {name: "requested", source: ModelSourceRequested, want: true}, + {name: "upstream", source: ModelSourceUpstream, want: true}, + {name: "mapping", source: ModelSourceMapping, want: true}, + {name: "invalid", source: "foobar", want: false}, + {name: "empty", source: "", want: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := IsValidModelSource(tc.source); got != tc.want { + t.Fatalf("IsValidModelSource(%q)=%v want %v", tc.source, got, tc.want) + } + }) + } +} + +func TestNormalizeModelSource(t *testing.T) { + tests := []struct { + name string + source string + want string + }{ + {name: "requested", source: ModelSourceRequested, want: ModelSourceRequested}, + {name: "upstream", 
source: ModelSourceUpstream, want: ModelSourceUpstream}, + {name: "mapping", source: ModelSourceMapping, want: ModelSourceMapping}, + {name: "invalid falls back", source: "foobar", want: ModelSourceRequested}, + {name: "empty falls back", source: "", want: ModelSourceRequested}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := NormalizeModelSource(tc.source); got != tc.want { + t.Fatalf("NormalizeModelSource(%q)=%q want %q", tc.source, got, tc.want) + } + }) + } +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index c195f1f1..674c655b 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -88,8 +88,9 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group if err != nil { return nil, err } - count, _ := r.GetAccountCount(ctx, out.ID) - out.AccountCount = count + total, active, _ := r.GetAccountCount(ctx, out.ID) + out.AccountCount = total + out.ActiveAccountCount = active return out, nil } @@ -256,7 +257,10 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -283,7 +287,10 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -310,7 +317,10 @@ func (r *groupRepository) ListActiveByPlatform(ctx 
context.Context, platform str counts, err := r.loadAccountCounts(ctx, groupIDs) if err == nil { for i := range outGroups { - outGroups[i].AccountCount = counts[outGroups[i].ID] + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited } } @@ -369,12 +379,20 @@ func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int return result, nil } -func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - var count int64 - if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { - return 0, err - } - return count, nil +func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) { + var rateLimited int64 + err = scanSingleRow(ctx, r.sql, + `SELECT COUNT(*), + COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true), + COUNT(*) FILTER (WHERE a.status = 'active' AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )) + FROM account_groups ag JOIN accounts a ON a.id = ag.account_id + WHERE ag.group_id = $1`, + []any{groupID}, &total, &active, &rateLimited) + return } func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { @@ -500,15 +518,32 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, return affectedUserIDs, nil } -func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) { - counts = make(map[int64]int64, len(groupIDs)) +type groupAccountCounts struct { + Total int64 + Active int64 + RateLimited int64 +} + +func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]groupAccountCounts, err error) { + counts = 
make(map[int64]groupAccountCounts, len(groupIDs)) if len(groupIDs) == 0 { return counts, nil } rows, err := r.sql.QueryContext( ctx, - "SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id", + `SELECT ag.group_id, + COUNT(*) AS total, + COUNT(*) FILTER (WHERE a.status = 'active' AND a.schedulable = true) AS active, + COUNT(*) FILTER (WHERE a.status = 'active' AND ( + a.rate_limit_reset_at > NOW() OR + a.overload_until > NOW() OR + a.temp_unschedulable_until > NOW() + )) AS rate_limited + FROM account_groups ag + JOIN accounts a ON a.id = ag.account_id + WHERE ag.group_id = ANY($1) + GROUP BY ag.group_id`, pq.Array(groupIDs), ) if err != nil { @@ -523,11 +558,11 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6 for rows.Next() { var groupID int64 - var count int64 - if err = rows.Scan(&groupID, &count); err != nil { + var c groupAccountCounts + if err = rows.Scan(&groupID, &c.Total, &c.Active, &c.RateLimited); err != nil { return nil, err } - counts[groupID] = count + counts[groupID] = c } if err = rows.Err(); err != nil { return nil, err diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index 4a849a46..eccf5cea 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -603,7 +603,7 @@ func (s *GroupRepoSuite) TestGetAccountCount() { _, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2) s.Require().NoError(err) - count, err := s.repo.GetAccountCount(s.ctx, group.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err, "GetAccountCount") s.Require().Equal(int64(2), count) } @@ -619,7 +619,7 @@ func (s *GroupRepoSuite) TestGetAccountCount_Empty() { } s.Require().NoError(s.repo.Create(s.ctx, group)) - count, err 
:= s.repo.GetAccountCount(s.ctx, group.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, group.ID) s.Require().NoError(err) s.Require().Zero(count) } @@ -651,7 +651,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() { s.Require().NoError(err, "DeleteAccountGroupsByGroupID") s.Require().Equal(int64(1), affected, "expected 1 affected row") - count, err := s.repo.GetAccountCount(s.ctx, g.ID) + count, _, err := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().NoError(err, "GetAccountCount") s.Require().Equal(int64(0), count, "expected 0 account groups") } @@ -692,7 +692,7 @@ func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() { s.Require().NoError(err) s.Require().Equal(int64(3), affected) - count, _ := s.repo.GetAccountCount(s.ctx, g.ID) + count, _, _ := s.repo.GetAccountCount(s.ctx, g.ID) s.Require().Zero(count) } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dcdaeaee..ca454606 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, 
cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, created_at" var usageLogInsertArgTypes = [...]string{ "bigint", @@ -36,6 +36,7 @@ var usageLogInsertArgTypes = [...]string{ "bigint", "text", "text", + "text", "bigint", "bigint", "integer", @@ -277,6 +278,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -311,12 +313,12 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -707,6 +709,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -742,7 +745,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*38) + args := make([]any, 0, len(keys)*39) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -776,6 +779,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + 
upstream_model, group_id, subscription_id, input_tokens, @@ -816,6 +820,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -896,6 +901,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -931,7 +937,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*38) + args := make([]any, 0, len(preparedList)*39) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -962,6 +968,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1002,6 +1009,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1050,6 +1058,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared account_id, request_id, model, + upstream_model, group_id, subscription_id, input_tokens, @@ -1084,12 +1093,12 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared cache_ttl_overridden, created_at ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38 + $1, $2, $3, $4, $5, $6, + $7, $8, + $9, $10, $11, $12, + $13, $14, + $15, $16, $17, $18, $19, $20, + $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) 
@@ -1121,6 +1130,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) upstreamEndpoint := nullString(log.UpstreamEndpoint) + upstreamModel := nullString(log.UpstreamModel) var requestIDArg any if requestID != "" { @@ -1138,6 +1148,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { log.AccountID, requestIDArg, log.Model, + upstreamModel, groupID, subscriptionID, log.InputTokens, @@ -2864,15 +2875,26 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st // GetModelStatsWithFilters returns model statistics with optional filters func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) +} + +// GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension. +// source: requested | upstream | mapping. 
+func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, source) +} + +func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } + modelExpr := resolveModelDimensionExpression(source) query := fmt.Sprintf(` SELECT - model, + %s as model, COUNT(*) as requests, COALESCE(SUM(input_tokens), 0) as input_tokens, COALESCE(SUM(output_tokens), 0) as output_tokens, @@ -2883,7 +2905,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start %s FROM usage_logs WHERE created_at >= $1 AND created_at < $2 - `, actualCostExpr) + `, modelExpr, actualCostExpr) args := []any{startTime, endTime} if userID > 0 { @@ -2907,7 +2929,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) } - query += " GROUP BY model ORDER BY total_tokens DESC" + query += fmt.Sprintf(" GROUP BY %s ORDER BY total_tokens DESC", modelExpr) rows, err := r.sql.QueryContext(ctx, query, args...) 
if err != nil { @@ -3021,7 +3043,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim args = append(args, dim.GroupID) } if dim.Model != "" { - query += fmt.Sprintf(" AND ul.model = $%d", len(args)+1) + query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1) args = append(args, dim.Model) } if dim.Endpoint != "" { @@ -3067,6 +3089,53 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim return results, nil } +// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group. +// todayStart is the start-of-day in the caller's timezone (UTC-based). +// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation. +// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s) +// or a materialized view / pre-aggregation table for cumulative costs. +func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + query := ` + SELECT + g.id AS group_id, + COALESCE(SUM(ul.actual_cost), 0) AS total_cost, + COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost + FROM groups g + LEFT JOIN usage_logs ul ON ul.group_id = g.id + GROUP BY g.id + ` + + rows, err := r.sql.QueryContext(ctx, query, todayStart) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var results []usagestats.GroupUsageSummary + for rows.Next() { + var row usagestats.GroupUsageSummary + if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +// resolveModelDimensionExpression maps model source type to a safe SQL expression. 
+func resolveModelDimensionExpression(modelType string) string { + switch usagestats.NormalizeModelSource(modelType) { + case usagestats.ModelSourceUpstream: + return "COALESCE(NULLIF(TRIM(upstream_model), ''), model)" + case usagestats.ModelSourceMapping: + return "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))" + default: + return "model" + } +} + // resolveEndpointColumn maps endpoint type to the corresponding DB column name. func resolveEndpointColumn(endpointType string) string { switch endpointType { @@ -3819,6 +3888,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e accountID int64 requestID sql.NullString model string + upstreamModel sql.NullString groupID sql.NullInt64 subscriptionID sql.NullInt64 inputTokens int @@ -3861,6 +3931,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &accountID, &requestID, &model, + &upstreamModel, &groupID, &subscriptionID, &inputTokens, @@ -3973,6 +4044,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if upstreamEndpoint.Valid { log.UpstreamEndpoint = &upstreamEndpoint.String } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } return log, nil } diff --git a/backend/internal/repository/usage_log_repo_breakdown_test.go b/backend/internal/repository/usage_log_repo_breakdown_test.go index ca63e0bc..5d908bfd 100644 --- a/backend/internal/repository/usage_log_repo_breakdown_test.go +++ b/backend/internal/repository/usage_log_repo_breakdown_test.go @@ -5,6 +5,7 @@ package repository import ( "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/stretchr/testify/require" ) @@ -16,8 +17,8 @@ func TestResolveEndpointColumn(t *testing.T) { {"inbound", "ul.inbound_endpoint"}, {"upstream", "ul.upstream_endpoint"}, {"path", "ul.inbound_endpoint || ' -> ' || ul.upstream_endpoint"}, - {"", "ul.inbound_endpoint"}, // default - {"unknown", "ul.inbound_endpoint"}, 
// fallback + {"", "ul.inbound_endpoint"}, // default + {"unknown", "ul.inbound_endpoint"}, // fallback } for _, tc := range tests { @@ -27,3 +28,23 @@ func TestResolveEndpointColumn(t *testing.T) { }) } } + +func TestResolveModelDimensionExpression(t *testing.T) { + tests := []struct { + modelType string + want string + }{ + {usagestats.ModelSourceRequested, "model"}, + {usagestats.ModelSourceUpstream, "COALESCE(NULLIF(TRIM(upstream_model), ''), model)"}, + {usagestats.ModelSourceMapping, "(model || ' -> ' || COALESCE(NULLIF(TRIM(upstream_model), ''), model))"}, + {"", "model"}, + {"invalid", "model"}, + } + + for _, tc := range tests { + t.Run(tc.modelType, func(t *testing.T) { + got := resolveModelDimensionExpression(tc.modelType) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index 27ae4571..76827c31 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -44,6 +44,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { log.AccountID, log.RequestID, log.Model, + sqlmock.AnyArg(), // upstream_model sqlmock.AnyArg(), // group_id sqlmock.AnyArg(), // subscription_id log.InputTokens, @@ -116,6 +117,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { log.Model, sqlmock.AnyArg(), sqlmock.AnyArg(), + sqlmock.AnyArg(), log.InputTokens, log.OutputTokens, log.CacheCreationTokens, @@ -353,6 +355,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(30), // account_id sql.NullString{Valid: true, String: "req-1"}, "gpt-5", // model + sql.NullString{}, // upstream_model sql.NullInt64{}, // group_id sql.NullInt64{}, // subscription_id 1, // input_tokens @@ -404,6 +407,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(31), sql.NullString{Valid: 
true, String: "req-2"}, "gpt-5", + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, @@ -445,6 +449,7 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { int64(32), sql.NullString{Valid: true, String: "req-3"}, "gpt-5.4", + sql.NullString{}, sql.NullInt64{}, sql.NullInt64{}, 1, 2, 3, 4, 5, 6, diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 5a649846..e3f64a5f 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -5,6 +5,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" @@ -190,7 +191,7 @@ func (r *userSubscriptionRepository) ListByGroupID(ctx context.Context, groupID return userSubscriptionEntitiesToService(subs), paginationResultFromTotal(int64(total), params), nil } -func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *userSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { client := clientFromContext(ctx, r.client) q := client.UserSubscription.Query() if userID != nil { @@ -199,6 +200,9 @@ func (r *userSubscriptionRepository) List(ctx context.Context, params pagination if groupID != nil { q = q.Where(usersubscription.GroupIDEQ(*groupID)) } + if platform != "" { + q = q.Where(usersubscription.HasGroupWith(group.PlatformEQ(platform))) + } // Status filtering with real-time expiration check now := time.Now() diff --git 
a/backend/internal/repository/user_subscription_repo_integration_test.go b/backend/internal/repository/user_subscription_repo_integration_test.go index 60a5a378..a74860e3 100644 --- a/backend/internal/repository/user_subscription_repo_integration_test.go +++ b/backend/internal/repository/user_subscription_repo_integration_test.go @@ -271,7 +271,7 @@ func (s *UserSubscriptionRepoSuite) TestList_NoFilters() { group := s.mustCreateGroup("g-list") s.mustCreateSubscription(user.ID, group.ID, nil) - subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "") + subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "") s.Require().NoError(err, "List") s.Require().Len(subs, 1) s.Require().Equal(int64(1), page.Total) @@ -285,7 +285,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() { s.mustCreateSubscription(user1.ID, group.ID, nil) s.mustCreateSubscription(user2.ID, group.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(user1.ID, subs[0].UserID) @@ -299,7 +299,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() { s.mustCreateSubscription(user.ID, g1.ID, nil) s.mustCreateSubscription(user.ID, g2.ID, nil) - subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(g1.ID, subs[0].GroupID) @@ -320,7 +320,7 @@ func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() { c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) }) - subs, _, err := 
s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "") + subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "") s.Require().NoError(err) s.Require().Len(subs, 1) s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 309dcf4e..4ae5c272 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -924,8 +924,8 @@ func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error return false, errors.New("not implemented") } -func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, errors.New("not implemented") +func (stubGroupRepo) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, errors.New("not implemented") } func (stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { @@ -1289,7 +1289,7 @@ func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userI func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func 
(stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { @@ -1786,6 +1786,9 @@ func (r *stubUsageLogRepo) GetAccountUsageStats(ctx context.Context, accountID i func (r *stubUsageLogRepo) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + return nil, errors.New("not implemented") +} type stubSettingRepo struct { all map[string]string diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 49db5f19..9f9bba13 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -135,7 +135,7 @@ func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, user func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } -func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { diff --git a/backend/internal/server/middleware/api_key_auth_test.go 
b/backend/internal/server/middleware/api_key_auth_test.go index 22befa2a..a633ffdd 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -646,7 +646,7 @@ func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID in return nil, nil, errors.New("not implemented") } -func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { +func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 67d7cb45..c80cca54 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -227,6 +227,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { groups.GET("", h.Admin.Group.List) groups.GET("/all", h.Admin.Group.GetAll) + groups.GET("/usage-summary", h.Admin.Group.GetUsageSummary) + groups.GET("/capacity-summary", h.Admin.Group.GetCapacitySummary) groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) @@ -400,6 +402,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey) adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey) adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey) + // 529过载冷却配置 + adminSettings.GET("/overload-cooldown", h.Admin.Setting.GetOverloadCooldownSettings) + adminSettings.PUT("/overload-cooldown", 
h.Admin.Setting.UpdateOverloadCooldownSettings) // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 482d22b1..d30b670d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -113,15 +113,18 @@ func (s *AccountTestService) validateUpstreamBaseURL(raw string) (string, error) return normalized, nil } -// generateSessionString generates a Claude Code style session string +// generateSessionString generates a Claude Code style session string. +// The output format is determined by the UA version in claude.DefaultHeaders, +// ensuring consistency between the user_id format and the UA sent to upstream. func generateSessionString() (string, error) { - bytes := make([]byte, 32) - if _, err := rand.Read(bytes); err != nil { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { return "", err } - hex64 := hex.EncodeToString(bytes) + hex64 := hex.EncodeToString(b) sessionUUID := uuid.New().String() - return fmt.Sprintf("user_%s_account__session_%s", hex64, sessionUUID), nil + uaVersion := ExtractCLIVersion(claude.DefaultHeaders["User-Agent"]) + return FormatMetadataUserID(hex64, "", sessionUUID, uaVersion), nil } // createTestPayload creates a Claude Code style test request payload diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 4a05c64a..74142700 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -49,6 +49,7 @@ type UsageLogRepository interface { GetUpstreamEndpointStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, 
billingType *int8) ([]usagestats.EndpointStat, error) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) ([]usagestats.UserBreakdownItem, error) + GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index 88d2f492..7588c16d 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -194,7 +194,7 @@ func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, stri func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) { panic("unexpected") } -func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) { +func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected") } func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 2e0f7d90..662b4771 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ 
b/backend/internal/service/admin_service_delete_test.go @@ -160,7 +160,7 @@ func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, er panic("unexpected ExistsByName call") } -func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { +func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index ef77a980..536be0b5 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -100,7 +100,7 @@ func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, panic("unexpected ExistsByName call") } -func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } @@ -383,7 +383,7 @@ func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string panic("unexpected ExistsByName call") } -func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } @@ -458,7 +458,7 @@ func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, panic("unexpected ExistsByName call") } -func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) { +func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } diff --git a/backend/internal/service/antigravity_model_mapping_test.go 
b/backend/internal/service/antigravity_model_mapping_test.go index 71939d26..1dbe9870 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -57,16 +57,16 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "claude-opus-4-6-thinking", }, { - name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5", + name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-6", requestedModel: "claude-haiku-4-5", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "claude-sonnet-4-6", }, { - name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5", + name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-6", requestedModel: "claude-haiku-4-5-20251001", accountMapping: nil, - expected: "claude-sonnet-4-5", + expected: "claude-sonnet-4-6", }, { name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5", diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index f71098b1..4e8ced67 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -21,9 +21,6 @@ var ( // 带捕获组的版本提取正则 claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) - // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} - userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) - // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) systemPromptThreshold = 0.5 ) @@ -124,7 +121,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo return false } - if !userIDPattern.MatchString(userID) { + if ParseMetadataUserID(userID) == nil { return false } @@ -278,11 +275,7 @@ func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context // ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 // 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 func (v *ClaudeCodeValidator) ExtractVersion(ua string) 
string { - matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) - if len(matches) >= 2 { - return matches[1] - } - return "" + return ExtractCLIVersion(ua) } // SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index ad29990f..3e059e30 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -140,6 +140,27 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, modelSource string) ([]usagestats.ModelStat, error) { + normalizedSource := usagestats.NormalizeModelSource(modelSource) + if normalizedSource == usagestats.ModelSourceRequested { + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) + } + + type modelStatsBySourceRepo interface { + GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) ([]usagestats.ModelStat, error) + } + + if sourceRepo, ok := s.usageRepo.(modelStatsBySourceRepo); ok { + stats, err := sourceRepo.GetModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, normalizedSource) + if err != nil { + return nil, fmt.Errorf("get model stats with filters by source: %w", err) + } + return stats, nil + } + + return s.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) +} + func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, 
userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { @@ -148,6 +169,15 @@ func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTi return stats, nil } +// GetGroupUsageSummary returns today's and cumulative cost for all groups. +func (s *DashboardService) GetGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + results, err := s.usageRepo.GetAllGroupUsageSummary(ctx, todayStart) + if err != nil { + return nil, fmt.Errorf("get group usage summary: %w", err) + } + return results, nil +} + func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { data, err := s.cache.GetDashboardStats(ctx) if err != nil { diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2d8681d4..7b629e14 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -170,6 +170,13 @@ const ( // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" + // ========================= + // Overload Cooldown (529) + // ========================= + + // SettingKeyOverloadCooldownSettings stores JSON config for 529 overload cooldown handling. 
+ SettingKeyOverloadCooldownSettings = "overload_cooldown_settings" + // ========================= // Stream Timeout Handling // ========================= diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 789cbab8..c534a9b7 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -788,7 +788,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuc rateLimitService: &RateLimitService{}, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, 12, result.Usage.InputTokens) @@ -815,7 +815,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenTyp } svc := &GatewayService{} - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", "claude-3-5-sonnet-latest", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "requires apikey token") @@ -840,7 +840,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest } account := newAnthropicAPIKeyAccountForTest() - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := 
svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "upstream request failed") @@ -873,7 +873,7 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBo httpUpstream: upstream, } - result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", "x", false, time.Now()) require.Nil(t, result) require.Error(t, err) require.Contains(t, err.Error(), "empty response") diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index ea8fa784..718cd42a 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -278,8 +278,8 @@ func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, plat func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 3816aea9..29b6cfd6 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -28,6 +28,12 @@ var ( patternEmptyContentSpaced = []byte(`"content": []`) 
patternEmptyContentSp1 = []byte(`"content" : []`) patternEmptyContentSp2 = []byte(`"content" :[]`) + + // Fast-path patterns for empty text blocks: {"type":"text","text":""} + patternEmptyText = []byte(`"text":""`) + patternEmptyTextSpaced = []byte(`"text": ""`) + patternEmptyTextSp1 = []byte(`"text" : ""`) + patternEmptyTextSp2 = []byte(`"text" :""`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -233,15 +239,22 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { bytes.Contains(body, patternThinkingField) || bytes.Contains(body, patternThinkingFieldSpaced) - // Also check for empty content arrays that need fixing. + // Also check for empty content arrays and empty text blocks that need fixing. // Note: This is a heuristic check; the actual empty content handling is done below. hasEmptyContent := bytes.Contains(body, patternEmptyContent) || bytes.Contains(body, patternEmptyContentSpaced) || bytes.Contains(body, patternEmptyContentSp1) || bytes.Contains(body, patternEmptyContentSp2) + // Check for empty text blocks: {"type":"text","text":""} + // These cause upstream 400: "text content blocks must be non-empty" + hasEmptyTextBlock := bytes.Contains(body, patternEmptyText) || + bytes.Contains(body, patternEmptyTextSpaced) || + bytes.Contains(body, patternEmptyTextSp1) || + bytes.Contains(body, patternEmptyTextSp2) + // Fast path: nothing to process - if !hasThinkingContent && !hasEmptyContent { + if !hasThinkingContent && !hasEmptyContent && !hasEmptyTextBlock { return body } @@ -260,7 +273,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { bytes.Contains(body, patternTypeRedactedThinking) || bytes.Contains(body, patternTypeRedactedSpaced) || bytes.Contains(body, patternThinkingFieldSpaced) - if !hasEmptyContent && !containsThinkingBlocks { + if !hasEmptyContent && !hasEmptyTextBlock && !containsThinkingBlocks { if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { out = 
removeThinkingDependentContextStrategies(out) @@ -320,6 +333,16 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { blockType, _ := blockMap["type"].(string) + // Strip empty text blocks: {"type":"text","text":""} + // Upstream rejects these with 400: "text content blocks must be non-empty" + if blockType == "text" { + if txt, _ := blockMap["text"].(string); txt == "" { + modifiedThisMsg = true + ensureNewContent(bi) + continue + } + } + // Convert thinking blocks to text (preserve content) and drop redacted_thinking. switch blockType { case "thinking": diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index f60ed9fb..b11fee9b 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -404,6 +404,51 @@ func TestFilterThinkingBlocksForRetry_EmptyContentGetsPlaceholder(t *testing.T) require.NotEmpty(t, content0["text"]) } +func TestFilterThinkingBlocksForRetry_StripsEmptyTextBlocks(t *testing.T) { + // Empty text blocks cause upstream 400: "text content blocks must be non-empty" + input := []byte(`{ + "messages":[ + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":""}]}, + {"role":"assistant","content":[{"type":"text","text":""}]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + var req map[string]any + require.NoError(t, json.Unmarshal(out, &req)) + msgs, ok := req["messages"].([]any) + require.True(t, ok) + + // First message: empty text block stripped, "hello" preserved + msg0 := msgs[0].(map[string]any) + content0 := msg0["content"].([]any) + require.Len(t, content0, 1) + require.Equal(t, "hello", content0[0].(map[string]any)["text"]) + + // Second message: only had empty text block → gets placeholder + msg1 := msgs[1].(map[string]any) + content1 := msg1["content"].([]any) + require.Len(t, content1, 1) + block1 := content1[0].(map[string]any) + require.Equal(t, "text", block1["type"]) + 
require.NotEmpty(t, block1["text"]) +} + +func TestFilterThinkingBlocksForRetry_PreservesNonEmptyTextBlocks(t *testing.T) { + // Non-empty text blocks should pass through unchanged + input := []byte(`{ + "messages":[ + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ] + }`) + + out := FilterThinkingBlocksForRetry(input) + + // Fast path: no thinking content, no empty content, no empty text blocks → unchanged + require.Equal(t, input, out) +} + func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { input := []byte(`{ "thinking":{"type":"enabled","budget_tokens":1024}, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 0b50162a..7e962f7f 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -326,7 +326,6 @@ func isClaudeCodeCredentialScopeError(msg string) bool { // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) - sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 @@ -491,6 +490,7 @@ type ForwardResult struct { RequestID string Usage ClaudeUsage Model string + UpstreamModel string // Actual upstream model after mapping (empty = no mapping) Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) @@ -644,8 +644,8 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // 1. 
最高优先级:从 metadata.user_id 提取 session_xxx if parsed.MetadataUserID != "" { - if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { - return match[1] + if uid := ParseMetadataUserID(parsed.MetadataUserID); uid != nil && uid.SessionID != "" { + return uid.SessionID } } @@ -1026,13 +1026,13 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account sessionID = generateSessionUUID(seed) } - // Prefer the newer format that includes account_uuid (if present), - // otherwise fall back to the legacy Claude Code format. - accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) - if accountUUID != "" { - return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID) + // 根据指纹 UA 版本选择输出格式 + var uaVersion string + if fp != nil { + uaVersion = ExtractCLIVersion(fp.UserAgent) } - return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) } // GenerateSessionUUID creates a deterministic UUID4 from a seed string. 
@@ -3989,7 +3989,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A passthroughModel = mappedModel } } - return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody, passthroughModel, parsed.Stream, startTime) + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: passthroughBody, + RequestModel: passthroughModel, + OriginalModel: parsed.Model, + RequestStream: parsed.Stream, + StartTime: startTime, + }) } if account != nil && account.IsBedrock() { @@ -4513,6 +4519,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 + UpstreamModel: mappedModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -4520,14 +4527,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } +type anthropicPassthroughForwardInput struct { + Body []byte + RequestModel string + OriginalModel string + RequestStream bool + StartTime time.Time +} + func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( ctx context.Context, c *gin.Context, account *Account, body []byte, reqModel string, + originalModel string, reqStream bool, startTime time.Time, +) (*ForwardResult, error) { + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: body, + RequestModel: reqModel, + OriginalModel: originalModel, + RequestStream: reqStream, + StartTime: startTime, + }) +} + +func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( + ctx context.Context, + c *gin.Context, + account *Account, + input anthropicPassthroughForwardInput, ) (*ForwardResult, error) { token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { @@ -4543,19 +4574,19 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( } 
logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", - account.ID, account.Name, reqModel, reqStream) + account.ID, account.Name, input.RequestModel, input.RequestStream) if c != nil { c.Set("anthropic_passthrough", true) } // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 - setOpsUpstreamRequestBody(c, body) + setOpsUpstreamRequestBody(c, input.Body) var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, body, token) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token) releaseUpstreamCtx() if err != nil { return nil, err @@ -4713,8 +4744,8 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool - if reqStream { - streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) + if input.RequestStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel) if err != nil { return nil, err } @@ -4734,9 +4765,10 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( return &ForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, - Model: reqModel, - Stream: reqStream, - Duration: time.Since(startTime), + Model: input.OriginalModel, + UpstreamModel: input.RequestModel, + Stream: input.RequestStream, + Duration: time.Since(input.StartTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, }, nil @@ -5241,6 +5273,7 @@ func (s *GatewayService) forwardBedrock( RequestID: 
resp.Header.Get("x-amzn-requestid"), Usage: *usage, Model: reqModel, + UpstreamModel: mappedModel, Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, @@ -5533,7 +5566,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { body = newBody } } @@ -6068,9 +6101,11 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return true } - // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) + // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block) // 例如: "all messages must have non-empty content" - if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { + // "messages: text content blocks must be non-empty" + if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") || + strings.Contains(msg, "content blocks must be non-empty") { logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") return true } @@ -7529,6 +7564,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu AccountID: account.ID, RequestID: requestID, Model: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), @@ -7710,6 +7746,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * AccountID: account.ID, RequestID: requestID, Model: 
result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), @@ -8161,7 +8198,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if err == nil { accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { body = newBody } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index b0b804eb..a78c56e7 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -230,8 +230,8 @@ func (m *mockGroupRepoForGemini) ListActiveByPlatform(ctx context.Context, platf func (m *mockGroupRepoForGemini) ExistsByName(ctx context.Context, name string) (bool, error) { return false, nil } -func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { - return 0, nil +func (m *mockGroupRepoForGemini) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil } func (m *mockGroupRepoForGemini) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go index 8aa358a5..f91fb4c9 100644 --- a/backend/internal/service/generate_session_hash_test.go +++ b/backend/internal/service/generate_session_hash_test.go @@ -24,7 +24,7 @@ func 
TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { svc := &GatewayService{} parsed := &ParsedRequest{ - MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000", System: "You are a helpful assistant.", HasSystem: true, Messages: []any{ @@ -196,7 +196,7 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { svc := &GatewayService{} parsed := &ParsedRequest{ - MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + MetadataUserID: "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000", Messages: []any{ map[string]any{"role": "user", "content": "hello"}, }, @@ -212,6 +212,22 @@ func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { "metadata session_id should take priority over SessionContext") } +func TestGenerateSessionHash_MetadataJSON_HasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: `{"device_id":"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`, + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", hash, "JSON format metadata session_id should have highest priority") +} + func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { svc := &GatewayService{} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 537b5a3b..e17032e0 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -64,8 +64,10 @@ type Group struct { CreatedAt time.Time UpdatedAt 
time.Time - AccountGroups []AccountGroup - AccountCount int64 + AccountGroups []AccountGroup + AccountCount int64 + ActiveAccountCount int64 + RateLimitedAccountCount int64 } func (g *Group) IsActive() bool { diff --git a/backend/internal/service/group_capacity_service.go b/backend/internal/service/group_capacity_service.go new file mode 100644 index 00000000..459084dc --- /dev/null +++ b/backend/internal/service/group_capacity_service.go @@ -0,0 +1,131 @@ +package service + +import ( + "context" + "time" +) + +// GroupCapacitySummary holds aggregated capacity for a single group. +type GroupCapacitySummary struct { + GroupID int64 `json:"group_id"` + ConcurrencyUsed int `json:"concurrency_used"` + ConcurrencyMax int `json:"concurrency_max"` + SessionsUsed int `json:"sessions_used"` + SessionsMax int `json:"sessions_max"` + RPMUsed int `json:"rpm_used"` + RPMMax int `json:"rpm_max"` +} + +// GroupCapacityService aggregates per-group capacity from runtime data. +type GroupCapacityService struct { + accountRepo AccountRepository + groupRepo GroupRepository + concurrencyService *ConcurrencyService + sessionLimitCache SessionLimitCache + rpmCache RPMCache +} + +// NewGroupCapacityService creates a new GroupCapacityService. +func NewGroupCapacityService( + accountRepo AccountRepository, + groupRepo GroupRepository, + concurrencyService *ConcurrencyService, + sessionLimitCache SessionLimitCache, + rpmCache RPMCache, +) *GroupCapacityService { + return &GroupCapacityService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + concurrencyService: concurrencyService, + sessionLimitCache: sessionLimitCache, + rpmCache: rpmCache, + } +} + +// GetAllGroupCapacity returns capacity summary for all active groups. 
+func (s *GroupCapacityService) GetAllGroupCapacity(ctx context.Context) ([]GroupCapacitySummary, error) { + groups, err := s.groupRepo.ListActive(ctx) + if err != nil { + return nil, err + } + + results := make([]GroupCapacitySummary, 0, len(groups)) + for i := range groups { + cap, err := s.getGroupCapacity(ctx, groups[i].ID) + if err != nil { + // Skip groups with errors, return partial results + continue + } + cap.GroupID = groups[i].ID + results = append(results, cap) + } + return results, nil +} + +func (s *GroupCapacityService) getGroupCapacity(ctx context.Context, groupID int64) (GroupCapacitySummary, error) { + accounts, err := s.accountRepo.ListSchedulableByGroupID(ctx, groupID) + if err != nil { + return GroupCapacitySummary{}, err + } + if len(accounts) == 0 { + return GroupCapacitySummary{}, nil + } + + // Collect account IDs and config values + accountIDs := make([]int64, 0, len(accounts)) + sessionTimeouts := make(map[int64]time.Duration) + var concurrencyMax, sessionsMax, rpmMax int + + for i := range accounts { + acc := &accounts[i] + accountIDs = append(accountIDs, acc.ID) + concurrencyMax += acc.Concurrency + + if ms := acc.GetMaxSessions(); ms > 0 { + sessionsMax += ms + timeout := time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute + if timeout <= 0 { + timeout = 5 * time.Minute + } + sessionTimeouts[acc.ID] = timeout + } + + if rpm := acc.GetBaseRPM(); rpm > 0 { + rpmMax += rpm + } + } + + // Batch query runtime data from Redis + concurrencyMap, _ := s.concurrencyService.GetAccountConcurrencyBatch(ctx, accountIDs) + + var sessionsMap map[int64]int + if sessionsMax > 0 && s.sessionLimitCache != nil { + sessionsMap, _ = s.sessionLimitCache.GetActiveSessionCountBatch(ctx, accountIDs, sessionTimeouts) + } + + var rpmMap map[int64]int + if rpmMax > 0 && s.rpmCache != nil { + rpmMap, _ = s.rpmCache.GetRPMBatch(ctx, accountIDs) + } + + // Aggregate + var concurrencyUsed, sessionsUsed, rpmUsed int + for _, id := range accountIDs { + 
concurrencyUsed += concurrencyMap[id] + if sessionsMap != nil { + sessionsUsed += sessionsMap[id] + } + if rpmMap != nil { + rpmUsed += rpmMap[id] + } + } + + return GroupCapacitySummary{ + ConcurrencyUsed: concurrencyUsed, + ConcurrencyMax: concurrencyMax, + SessionsUsed: sessionsUsed, + SessionsMax: sessionsMax, + RPMUsed: rpmUsed, + RPMMax: rpmMax, + }, nil +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 22a67eda..87174e03 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -27,7 +27,7 @@ type GroupRepository interface { ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) ExistsByName(ctx context.Context, name string) (bool, error) - GetAccountCount(ctx context.Context, groupID int64) (int64, error) + GetAccountCount(ctx context.Context, groupID int64) (total int64, active int64, err error) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) // GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) @@ -202,7 +202,7 @@ func (s *GroupService) GetStats(ctx context.Context, id int64) (map[string]any, } // 获取账号数量 - accountCount, err := s.groupRepo.GetAccountCount(ctx, id) + accountCount, _, err := s.groupRepo.GetAccountCount(ctx, id) if err != nil { return nil, fmt.Errorf("get account count: %w", err) } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index f6a94d15..8d464a8b 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -19,10 +19,6 @@ import ( // 预编译正则表达式(避免每次调用重新编译) var ( - // 匹配 user_id 格式: - // 旧格式: user_{64位hex}_account__session_{uuid} (account 后无 UUID) - // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} (account 后有 UUID) - userIDRegex = 
regexp.MustCompile(`^user_[a-f0-9]{64}_account_([a-f0-9-]*)_session_([a-f0-9-]{36})$`) // 匹配 User-Agent 版本号: xxx/x.y.z userAgentVersionRegex = regexp.MustCompile(`/(\d+)\.(\d+)\.(\d+)`) ) @@ -209,12 +205,12 @@ func (s *IdentityService) ApplyFingerprint(req *http.Request, fp *Fingerprint) { } // RewriteUserID 重写body中的metadata.user_id -// 输入格式:user_{clientId}_account__session_{sessionUUID} -// 输出格式:user_{cachedClientID}_account_{accountUUID}_session_{newHash} +// 支持旧拼接格式和新 JSON 格式的 user_id 解析, +// 根据 fingerprintUA 版本选择输出格式。 // // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 避免重新序列化导致 thinking 块等内容被修改。 -func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID string) ([]byte, error) { +func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) { if len(body) == 0 || accountUUID == "" || cachedClientID == "" { return body, nil } @@ -241,24 +237,21 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return body, nil } - // 匹配格式: - // 旧格式: user_{64位hex}_account__session_{uuid} - // 新格式: user_{64位hex}_account_{uuid}_session_{uuid} - matches := userIDRegex.FindStringSubmatch(userID) - if matches == nil { + // 解析 user_id(兼容旧拼接格式和新 JSON 格式) + parsed := ParseMetadataUserID(userID) + if parsed == nil { return body, nil } - // matches[1] = account UUID (可能为空), matches[2] = session UUID - sessionTail := matches[2] // 原始session UUID + sessionTail := parsed.SessionID // 原始session UUID // 生成新的session hash: SHA256(accountID::sessionTail) -> UUID格式 seed := fmt.Sprintf("%d::%s", accountID, sessionTail) newSessionHash := generateUUIDFromSeed(seed) - // 构建新的user_id - // 格式: user_{cachedClientID}_account_{account_uuid}_session_{newSessionHash} - newUserID := fmt.Sprintf("user_%s_account_%s_session_%s", cachedClientID, accountUUID, newSessionHash) + // 根据客户端版本选择输出格式 + version := ExtractCLIVersion(fingerprintUA) + newUserID := 
FormatMetadataUserID(cachedClientID, accountUUID, newSessionHash, version) metadata["user_id"] = newUserID @@ -278,9 +271,9 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI // // 重要:此函数使用 json.RawMessage 保留其他字段的原始字节, // 避免重新序列化导致 thinking 块等内容被修改。 -func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { +func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID, fingerprintUA string) ([]byte, error) { // 先执行常规的 RewriteUserID 逻辑 - newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) + newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID, fingerprintUA) if err != nil { return newBody, err } @@ -312,10 +305,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b return newBody, nil } - // 查找 _session_ 的位置,替换其后的内容 - const sessionMarker = "_session_" - idx := strings.LastIndex(userID, sessionMarker) - if idx == -1 { + // 解析已重写的 user_id + uidParsed := ParseMetadataUserID(userID) + if uidParsed == nil { return newBody, nil } @@ -337,8 +329,9 @@ func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []b logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) } - // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 - newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID + // 用 FormatMetadataUserID 重建(保持与 RewriteUserID 相同的格式) + version := ExtractCLIVersion(fingerprintUA) + newUserID := FormatMetadataUserID(uidParsed.DeviceID, uidParsed.AccountUUID, maskedSessionID, version) slog.Debug("session_id_masking_applied", "account_id", account.ID, diff --git a/backend/internal/service/metadata_userid.go b/backend/internal/service/metadata_userid.go new file mode 100644 index 00000000..ee1ef64a --- /dev/null +++ 
b/backend/internal/service/metadata_userid.go @@ -0,0 +1,104 @@ +package service + +import ( + "encoding/json" + "regexp" + "strings" +) + +// NewMetadataFormatMinVersion is the minimum Claude Code version that uses +// JSON-formatted metadata.user_id instead of the legacy concatenated string. +const NewMetadataFormatMinVersion = "2.1.78" + +// ParsedUserID represents the components extracted from a metadata.user_id value. +type ParsedUserID struct { + DeviceID string // 64-char hex (or arbitrary client id) + AccountUUID string // may be empty + SessionID string // UUID + IsNewFormat bool // true if the original was JSON format +} + +// legacyUserIDRegex matches the legacy user_id format: +// +// user_{64hex}_account_{optional_uuid}_session_{uuid} +var legacyUserIDRegex = regexp.MustCompile(`^user_([a-fA-F0-9]{64})_account_([a-fA-F0-9-]*)_session_([a-fA-F0-9-]{36})$`) + +// jsonUserID is the JSON structure for the new metadata.user_id format. +type jsonUserID struct { + DeviceID string `json:"device_id"` + AccountUUID string `json:"account_uuid"` + SessionID string `json:"session_id"` +} + +// ParseMetadataUserID parses a metadata.user_id string in either format. +// Returns nil if the input cannot be parsed. 
+func ParseMetadataUserID(raw string) *ParsedUserID { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + // Try JSON format first (starts with '{') + if raw[0] == '{' { + var j jsonUserID + if err := json.Unmarshal([]byte(raw), &j); err != nil { + return nil + } + if j.DeviceID == "" || j.SessionID == "" { + return nil + } + return &ParsedUserID{ + DeviceID: j.DeviceID, + AccountUUID: j.AccountUUID, + SessionID: j.SessionID, + IsNewFormat: true, + } + } + + // Try legacy format + matches := legacyUserIDRegex.FindStringSubmatch(raw) + if matches == nil { + return nil + } + return &ParsedUserID{ + DeviceID: matches[1], + AccountUUID: matches[2], + SessionID: matches[3], + IsNewFormat: false, + } +} + +// FormatMetadataUserID builds a metadata.user_id string in the format +// appropriate for the given CLI version. Components are the rewritten values +// (not necessarily the originals). +func FormatMetadataUserID(deviceID, accountUUID, sessionID, uaVersion string) string { + if IsNewMetadataFormatVersion(uaVersion) { + b, _ := json.Marshal(jsonUserID{ + DeviceID: deviceID, + AccountUUID: accountUUID, + SessionID: sessionID, + }) + return string(b) + } + // Legacy format + return "user_" + deviceID + "_account_" + accountUUID + "_session_" + sessionID +} + +// IsNewMetadataFormatVersion returns true if the given CLI version uses the +// new JSON metadata.user_id format (>= 2.1.78). +func IsNewMetadataFormatVersion(version string) bool { + if version == "" { + return false + } + return CompareVersions(version, NewMetadataFormatMinVersion) >= 0 +} + +// ExtractCLIVersion extracts the Claude Code version from a User-Agent string. +// Returns "" if the UA doesn't match the expected pattern. 
+func ExtractCLIVersion(ua string) string { + matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) + if len(matches) >= 2 { + return matches[1] + } + return "" +} diff --git a/backend/internal/service/metadata_userid_test.go b/backend/internal/service/metadata_userid_test.go new file mode 100644 index 00000000..40ad7087 --- /dev/null +++ b/backend/internal/service/metadata_userid_test.go @@ -0,0 +1,183 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ ParseMetadataUserID Tests ============ + +func TestParseMetadataUserID_LegacyFormat_WithoutAccountUUID(t *testing.T) { + raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account__session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_LegacyFormat_WithAccountUUID(t *testing.T) { + raw := "user_a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2_account_550e8400-e29b-41d4-a716-446655440000_session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", parsed.DeviceID) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_JSONFormat_WithoutAccountUUID(t *testing.T) { + raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}` + parsed := 
ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_JSONFormat_WithAccountUUID(t *testing.T) { + raw := `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"550e8400-e29b-41d4-a716-446655440000","session_id":"c72554f2-1234-5678-abcd-123456789abc"}` + parsed := ParseMetadataUserID(raw) + require.NotNil(t, parsed) + require.Equal(t, "d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677", parsed.DeviceID) + require.Equal(t, "550e8400-e29b-41d4-a716-446655440000", parsed.AccountUUID) + require.Equal(t, "c72554f2-1234-5678-abcd-123456789abc", parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseMetadataUserID_InvalidInputs(t *testing.T) { + tests := []struct { + name string + raw string + }{ + {"empty string", ""}, + {"whitespace only", " "}, + {"random text", "not-a-valid-user-id"}, + {"partial legacy format", "session_123e4567-e89b-12d3-a456-426614174000"}, + {"invalid JSON", `{"device_id":}`}, + {"JSON missing device_id", `{"account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`}, + {"JSON missing session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":""}`}, + {"JSON empty device_id", `{"device_id":"","account_uuid":"","session_id":"c72554f2-1234-5678-abcd-123456789abc"}`}, + {"JSON empty session_id", `{"device_id":"d61f76d0aabbccdd00112233445566778899aabbccddeeff0011223344556677","account_uuid":"","session_id":""}`}, + {"legacy format short hex", "user_a1b2c3d4_account__session_123e4567-e89b-12d3-a456-426614174000"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Nil(t, 
ParseMetadataUserID(tt.raw), "should return nil for: %s", tt.raw) + }) + } +} + +func TestParseMetadataUserID_HexCaseInsensitive(t *testing.T) { + // Legacy format should accept both upper and lower case hex + rawUpper := "user_A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2_account__session_123e4567-e89b-12d3-a456-426614174000" + parsed := ParseMetadataUserID(rawUpper) + require.NotNil(t, parsed, "legacy format should accept uppercase hex") + require.Equal(t, "A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", parsed.DeviceID) +} + +// ============ FormatMetadataUserID Tests ============ + +func TestFormatMetadataUserID_LegacyVersion(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.77") + require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account_acc-uuid_session_sess-uuid", result) +} + +func TestFormatMetadataUserID_NewVersion(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "acc-uuid", "sess-uuid", "2.1.78") + require.Equal(t, `{"device_id":"deadbeef00112233445566778899aabbccddeeff0011223344556677","account_uuid":"acc-uuid","session_id":"sess-uuid"}`, result) +} + +func TestFormatMetadataUserID_EmptyVersion_Legacy(t *testing.T) { + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "") + require.Equal(t, "user_deadbeef00112233445566778899aabbccddeeff0011223344556677_account__session_sess-uuid", result) +} + +func TestFormatMetadataUserID_EmptyAccountUUID(t *testing.T) { + // Legacy format with empty account UUID → double underscore + result := FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.22") + require.Contains(t, result, "_account__session_") + + // New format with empty account UUID → empty string in JSON + result = 
FormatMetadataUserID("deadbeef"+"00112233445566778899aabbccddeeff0011223344556677", "", "sess-uuid", "2.1.78") + require.Contains(t, result, `"account_uuid":""`) +} + +// ============ IsNewMetadataFormatVersion Tests ============ + +func TestIsNewMetadataFormatVersion(t *testing.T) { + tests := []struct { + version string + want bool + }{ + {"", false}, + {"2.1.77", false}, + {"2.1.78", true}, + {"2.1.79", true}, + {"2.2.0", true}, + {"3.0.0", true}, + {"2.0.100", false}, + {"1.9.99", false}, + } + for _, tt := range tests { + t.Run(tt.version, func(t *testing.T) { + require.Equal(t, tt.want, IsNewMetadataFormatVersion(tt.version)) + }) + } +} + +// ============ Round-trip Tests ============ + +func TestParseFormat_RoundTrip_Legacy(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + accountUUID := "550e8400-e29b-41d4-a716-446655440000" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.22") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, accountUUID, parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + require.False(t, parsed.IsNewFormat) +} + +func TestParseFormat_RoundTrip_JSON(t *testing.T) { + deviceID := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + accountUUID := "550e8400-e29b-41d4-a716-446655440000" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + formatted := FormatMetadataUserID(deviceID, accountUUID, sessionID, "2.1.78") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, accountUUID, parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + require.True(t, parsed.IsNewFormat) +} + +func TestParseFormat_RoundTrip_EmptyAccountUUID(t *testing.T) { + deviceID := 
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + sessionID := "123e4567-e89b-12d3-a456-426614174000" + + // Legacy round-trip with empty account UUID + formatted := FormatMetadataUserID(deviceID, "", sessionID, "2.1.22") + parsed := ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) + + // JSON round-trip with empty account UUID + formatted = FormatMetadataUserID(deviceID, "", sessionID, "2.1.78") + parsed = ParseMetadataUserID(formatted) + require.NotNil(t, parsed) + require.Equal(t, deviceID, parsed.DeviceID) + require.Equal(t, "", parsed.AccountUUID) + require.Equal(t, sessionID, parsed.SessionID) +} diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 9529f6be..7202f7cb 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -277,12 +277,13 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( c.JSON(http.StatusOK, chatResp) return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -324,13 +325,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: 
mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 58714571..6a29823a 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -299,12 +299,13 @@ func (s *OpenAIGatewayService) handleAnthropicBufferedStreamingResponse( c.JSON(http.StatusOK, anthropicResp) return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: false, - Duration: time.Since(startTime), + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: false, + Duration: time.Since(startTime), }, nil } @@ -347,13 +348,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse( // resultWithUsage builds the final result snapshot. resultWithUsage := func() *OpenAIForwardResult { return &OpenAIForwardResult{ - RequestID: requestID, - Usage: usage, - Model: originalModel, - BillingModel: mappedModel, - Stream: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: usage, + Model: originalModel, + BillingModel: mappedModel, + UpstreamModel: mappedModel, + Stream: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index ada7d805..a35f9127 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -846,7 +846,7 @@ func TestExtractOpenAIServiceTierFromBody(t *testing.T) { require.Nil(t, extractOpenAIServiceTierFromBody(nil)) } -func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *testing.T) { +func 
TestOpenAIGatewayServiceRecordUsage_UsesRequestedModelAndUpstreamModelMetadataFields(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} userRepo := &openAIRecordUsageUserRepoStub{} subRepo := &openAIRecordUsageSubRepoStub{} @@ -859,6 +859,7 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te RequestID: "resp_billing_model_override", BillingModel: "gpt-5.1-codex", Model: "gpt-5.1", + UpstreamModel: "gpt-5.1-codex", ServiceTier: &serviceTier, ReasoningEffort: &reasoning, Usage: OpenAIUsage{ @@ -877,7 +878,9 @@ func TestOpenAIGatewayServiceRecordUsage_UsesBillingModelAndMetadataFields(t *te require.NoError(t, err) require.NotNil(t, usageRepo.lastLog) - require.Equal(t, "gpt-5.1-codex", usageRepo.lastLog.Model) + require.Equal(t, "gpt-5.1", usageRepo.lastLog.Model) + require.NotNil(t, usageRepo.lastLog.UpstreamModel) + require.Equal(t, "gpt-5.1-codex", *usageRepo.lastLog.UpstreamModel) require.NotNil(t, usageRepo.lastLog.ServiceTier) require.Equal(t, serviceTier, *usageRepo.lastLog.ServiceTier) require.NotNil(t, usageRepo.lastLog.ReasoningEffort) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index c8876edb..cf902c20 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -216,6 +216,9 @@ type OpenAIForwardResult struct { // This is set by the Anthropic Messages conversion path where // the mapped upstream model differs from the client-facing model. BillingModel string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Empty when no mapping was applied (requested model was used as-is). + UpstreamModel string // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". // Nil means the request did not specify a recognized tier. 
ServiceTier *string @@ -2128,6 +2131,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco firstTokenMs, wsAttempts, ) + wsResult.UpstreamModel = mappedModel return wsResult, nil } s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) @@ -2263,6 +2267,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, + UpstreamModel: mappedModel, ServiceTier: serviceTier, ReasoningEffort: reasoningEffort, Stream: reqStream, @@ -4134,7 +4139,8 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: requestID, - Model: billingModel, + Model: result.Model, + UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), ServiceTier: result.ServiceTier, ReasoningEffort: result.ReasoningEffort, InboundEndpoint: optionalTrimmedStringPtr(input.InboundEndpoint), @@ -4700,11 +4706,3 @@ func normalizeOpenAIReasoningEffort(raw string) string { return "" } } - -func optionalTrimmedStringPtr(raw string) *string { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return nil - } - return &trimmed -} diff --git a/backend/internal/service/overload_cooldown_test.go b/backend/internal/service/overload_cooldown_test.go new file mode 100644 index 00000000..ef5e7fd1 --- /dev/null +++ b/backend/internal/service/overload_cooldown_test.go @@ -0,0 +1,298 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// errSettingRepo: a SettingRepository that always returns errors on read +// --------------------------------------------------------------------------- + +type errSettingRepo struct { + mockSettingRepo // embed the existing mock 
from backup_service_test.go + readErr error +} + +func (r *errSettingRepo) GetValue(_ context.Context, _ string) (string, error) { + return "", r.readErr +} + +func (r *errSettingRepo) Get(_ context.Context, _ string) (*Setting, error) { + return nil, r.readErr +} + +// --------------------------------------------------------------------------- +// overloadAccountRepoStub: records SetOverloaded calls +// --------------------------------------------------------------------------- + +type overloadAccountRepoStub struct { + mockAccountRepoForGemini + overloadCalls int + lastOverloadID int64 + lastOverloadEnd time.Time +} + +func (r *overloadAccountRepoStub) SetOverloaded(_ context.Context, id int64, until time.Time) error { + r.overloadCalls++ + r.lastOverloadID = id + r.lastOverloadEnd = until + return nil +} + +// =========================================================================== +// SettingService: GetOverloadCooldownSettings +// =========================================================================== + +func TestGetOverloadCooldownSettings_DefaultsWhenNotSet(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_ReadsFromDB(t *testing.T) { + repo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 30}) + repo.data[SettingKeyOverloadCooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 30, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_ClampsMinValue(t *testing.T) { + repo := newMockSettingRepo() + data, _ := 
json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 0}) + repo.data[SettingKeyOverloadCooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_ClampsMaxValue(t *testing.T) { + repo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 999}) + repo.data[SettingKeyOverloadCooldownSettings] = string(data) + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 120, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_InvalidJSON_ReturnsDefaults(t *testing.T) { + repo := newMockSettingRepo() + repo.data[SettingKeyOverloadCooldownSettings] = "not-json" + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes) +} + +func TestGetOverloadCooldownSettings_EmptyValue_ReturnsDefaults(t *testing.T) { + repo := newMockSettingRepo() + repo.data[SettingKeyOverloadCooldownSettings] = "" + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.True(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes) +} + +// =========================================================================== +// SettingService: SetOverloadCooldownSettings +// =========================================================================== + +func TestSetOverloadCooldownSettings_Success(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + err := 
svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: false, + CooldownMinutes: 25, + }) + require.NoError(t, err) + + // Verify round-trip + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 25, settings.CooldownMinutes) +} + +func TestSetOverloadCooldownSettings_RejectsNil(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + err := svc.SetOverloadCooldownSettings(context.Background(), nil) + require.Error(t, err) +} + +func TestSetOverloadCooldownSettings_EnabledRejectsOutOfRange(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + + for _, minutes := range []int{0, -1, 121, 999} { + err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: true, CooldownMinutes: minutes, + }) + require.Error(t, err, "should reject enabled=true + cooldown_minutes=%d", minutes) + require.Contains(t, err.Error(), "cooldown_minutes must be between 1-120") + } +} + +func TestSetOverloadCooldownSettings_DisabledNormalizesOutOfRange(t *testing.T) { + repo := newMockSettingRepo() + svc := NewSettingService(repo, &config.Config{}) + + // enabled=false + cooldown_minutes=0 应该保存成功,值被归一化为10 + err := svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: false, CooldownMinutes: 0, + }) + require.NoError(t, err, "disabled with invalid minutes should NOT be rejected") + + // 验证持久化后读回来的值 + settings, err := svc.GetOverloadCooldownSettings(context.Background()) + require.NoError(t, err) + require.False(t, settings.Enabled) + require.Equal(t, 10, settings.CooldownMinutes, "should be normalized to default") +} + +func TestSetOverloadCooldownSettings_AcceptsBoundaries(t *testing.T) { + svc := NewSettingService(newMockSettingRepo(), &config.Config{}) + + for _, minutes := range []int{1, 60, 120} { + err := 
svc.SetOverloadCooldownSettings(context.Background(), &OverloadCooldownSettings{ + Enabled: true, CooldownMinutes: minutes, + }) + require.NoError(t, err, "should accept cooldown_minutes=%d", minutes) + } +} + +// =========================================================================== +// RateLimitService: handle529 behaviour +// =========================================================================== + +func TestHandle529_EnabledFromDB_PausesAccount(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + settingRepo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: true, CooldownMinutes: 15}) + settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data) + + settingSvc := NewSettingService(settingRepo, &config.Config{}) + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.Equal(t, int64(42), accountRepo.lastOverloadID) + require.WithinDuration(t, before.Add(15*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +func TestHandle529_DisabledFromDB_SkipsAccount(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + settingRepo := newMockSettingRepo() + data, _ := json.Marshal(OverloadCooldownSettings{Enabled: false, CooldownMinutes: 15}) + settingRepo.data[SettingKeyOverloadCooldownSettings] = string(data) + + settingSvc := NewSettingService(settingRepo, &config.Config{}) + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 42, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + svc.handle529(context.Background(), account) + + require.Equal(t, 0, accountRepo.overloadCalls, "should NOT pause when disabled") +} + +func 
TestHandle529_NilSettingService_FallsBackToConfig(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + cfg := &config.Config{} + cfg.RateLimit.OverloadCooldownMinutes = 20 + svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil) + // NOT calling SetSettingService — remains nil + + account := &Account{ID: 77, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.WithinDuration(t, before.Add(20*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +func TestHandle529_NilSettingService_ZeroConfig_DefaultsTen(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil) + + account := &Account{ID: 88, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.WithinDuration(t, before.Add(10*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +func TestHandle529_DBReadError_FallsBackToConfig(t *testing.T) { + accountRepo := &overloadAccountRepoStub{} + errRepo := &errSettingRepo{readErr: context.DeadlineExceeded} + errRepo.data = make(map[string]string) + + cfg := &config.Config{} + cfg.RateLimit.OverloadCooldownMinutes = 7 + settingSvc := NewSettingService(errRepo, cfg) + svc := NewRateLimitService(accountRepo, nil, cfg, nil, nil) + svc.SetSettingService(settingSvc) + + account := &Account{ID: 99, Platform: PlatformAnthropic, Type: AccountTypeOAuth} + before := time.Now() + svc.handle529(context.Background(), account) + + require.Equal(t, 1, accountRepo.overloadCalls) + require.WithinDuration(t, before.Add(7*time.Minute), accountRepo.lastOverloadEnd, 2*time.Second) +} + +// =========================================================================== +// Model: defaults & JSON round-trip +// 
=========================================================================== + +func TestDefaultOverloadCooldownSettings(t *testing.T) { + d := DefaultOverloadCooldownSettings() + require.True(t, d.Enabled) + require.Equal(t, 10, d.CooldownMinutes) +} + +func TestOverloadCooldownSettings_JSONRoundTrip(t *testing.T) { + original := OverloadCooldownSettings{Enabled: false, CooldownMinutes: 42} + data, err := json.Marshal(original) + require.NoError(t, err) + + var decoded OverloadCooldownSettings + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Equal(t, original, decoded) + + // Verify JSON uses snake_case field names + var raw map[string]any + require.NoError(t, json.Unmarshal(data, &raw)) + _, hasEnabled := raw["enabled"] + _, hasCooldown := raw["cooldown_minutes"] + require.True(t, hasEnabled, "JSON must use 'enabled'") + require.True(t, hasCooldown, "JSON must use 'cooldown_minutes'") +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index ef8a65c9..c59dd68d 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -1023,11 +1023,34 @@ func parseOpenAIRateLimitResetTime(body []byte) *int64 { } // handle529 处理529过载错误 -// 根据配置设置过载冷却时间 +// 根据配置决定是否暂停账号调度及冷却时长 func (s *RateLimitService) handle529(ctx context.Context, account *Account) { - cooldownMinutes := s.cfg.RateLimit.OverloadCooldownMinutes + var settings *OverloadCooldownSettings + if s.settingService != nil { + var err error + settings, err = s.settingService.GetOverloadCooldownSettings(ctx) + if err != nil { + slog.Warn("overload_settings_read_failed", "account_id", account.ID, "error", err) + settings = nil + } + } + // 回退到配置文件 + if settings == nil { + cooldown := s.cfg.RateLimit.OverloadCooldownMinutes + if cooldown <= 0 { + cooldown = 10 + } + settings = &OverloadCooldownSettings{Enabled: true, CooldownMinutes: cooldown} + } + + if !settings.Enabled { + 
slog.Info("account_529_ignored", "account_id", account.ID, "reason", "overload_cooldown_disabled") + return + } + + cooldownMinutes := settings.CooldownMinutes if cooldownMinutes <= 0 { - cooldownMinutes = 10 // 默认10分钟 + cooldownMinutes = 10 } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 141cdb39..ece95c4e 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -1172,6 +1172,57 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetOverloadCooldownSettings 获取529过载冷却配置 +func (s *SettingService) GetOverloadCooldownSettings(ctx context.Context) (*OverloadCooldownSettings, error) { + value, err := s.settingRepo.GetValue(ctx, SettingKeyOverloadCooldownSettings) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return DefaultOverloadCooldownSettings(), nil + } + return nil, fmt.Errorf("get overload cooldown settings: %w", err) + } + if value == "" { + return DefaultOverloadCooldownSettings(), nil + } + + var settings OverloadCooldownSettings + if err := json.Unmarshal([]byte(value), &settings); err != nil { + return DefaultOverloadCooldownSettings(), nil + } + + // 修正配置值范围 + if settings.CooldownMinutes < 1 { + settings.CooldownMinutes = 1 + } + if settings.CooldownMinutes > 120 { + settings.CooldownMinutes = 120 + } + + return &settings, nil +} + +// SetOverloadCooldownSettings 设置529过载冷却配置 +func (s *SettingService) SetOverloadCooldownSettings(ctx context.Context, settings *OverloadCooldownSettings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + // 禁用时修正为合法值即可,不拒绝请求 + if settings.CooldownMinutes < 1 || settings.CooldownMinutes > 120 { + if settings.Enabled { + return fmt.Errorf("cooldown_minutes must be between 1-120") + } + settings.CooldownMinutes = 10 // 禁用状态下归一化为默认值 + } + + 
data, err := json.Marshal(settings) + if err != nil { + return fmt.Errorf("marshal overload cooldown settings: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyOverloadCooldownSettings, string(data)) +} + // GetStreamTimeoutSettings 获取流超时处理配置 func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) { value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings) diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 71c2e7aa..23188a09 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -222,6 +222,22 @@ type BetaPolicySettings struct { Rules []BetaPolicyRule `json:"rules"` } +// OverloadCooldownSettings 529过载冷却配置 +type OverloadCooldownSettings struct { + // Enabled 是否在收到529时暂停账号调度 + Enabled bool `json:"enabled"` + // CooldownMinutes 冷却时长(分钟) + CooldownMinutes int `json:"cooldown_minutes"` +} + +// DefaultOverloadCooldownSettings 返回默认的过载冷却配置(启用,10分钟) +func DefaultOverloadCooldownSettings() *OverloadCooldownSettings { + return &OverloadCooldownSettings{ + Enabled: true, + CooldownMinutes: 10, + } +} + // DefaultBetaPolicySettings 返回默认的 Beta 策略配置 func DefaultBetaPolicySettings() *BetaPolicySettings { return &BetaPolicySettings{ diff --git a/backend/internal/service/sora_quota_service_test.go b/backend/internal/service/sora_quota_service_test.go index 040e427d..da8efe77 100644 --- a/backend/internal/service/sora_quota_service_test.go +++ b/backend/internal/service/sora_quota_service_test.go @@ -52,8 +52,8 @@ func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([ func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { return false, nil } -func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { - return 0, nil +func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, int64, error) { + return 
0, 0, nil } func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go index 0defafba..40bab206 100644 --- a/backend/internal/service/subscription_assign_idempotency_test.go +++ b/backend/internal/service/subscription_assign_idempotency_test.go @@ -40,7 +40,7 @@ func (groupRepoNoop) ListActiveByPlatform(context.Context, string) ([]Group, err func (groupRepoNoop) ExistsByName(context.Context, string) (bool, error) { panic("unexpected ExistsByName call") } -func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, error) { +func (groupRepoNoop) GetAccountCount(context.Context, int64) (int64, int64, error) { panic("unexpected GetAccountCount call") } func (groupRepoNoop) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { @@ -92,7 +92,7 @@ func (userSubRepoNoop) ListActiveByUserID(context.Context, int64) ([]UserSubscri func (userSubRepoNoop) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { panic("unexpected ListByGroupID call") } -func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { +func (userSubRepoNoop) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { panic("unexpected List call") } func (userSubRepoNoop) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index af548509..f0a5540e 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -634,9 
+634,9 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI } // List 获取所有订阅(分页,支持筛选和排序) -func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { +func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder) + subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, platform, sortBy, sortOrder) if err != nil { return nil, nil, err } diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 7f1bef7f..5a498676 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -98,6 +98,9 @@ type UsageLog struct { AccountID int64 RequestID string Model string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Nil means no mapping was applied (requested model was used as-is). + UpstreamModel *string // ServiceTier records the OpenAI service tier used for billing, e.g. "priority" / "flex". ServiceTier *string // ReasoningEffort is the request's reasoning effort level. 
diff --git a/backend/internal/service/usage_log_helpers.go b/backend/internal/service/usage_log_helpers.go new file mode 100644 index 00000000..2ab51849 --- /dev/null +++ b/backend/internal/service/usage_log_helpers.go @@ -0,0 +1,21 @@ +package service + +import "strings" + +func optionalTrimmedStringPtr(raw string) *string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil + } + return &trimmed +} + +// optionalNonEqualStringPtr returns a pointer to value if it is non-empty and +// differs from compare; otherwise nil. Used to store upstream_model only when +// it differs from the requested model. +func optionalNonEqualStringPtr(value, compare string) *string { + if value == "" || value == compare { + return nil + } + return &value +} diff --git a/backend/internal/service/user_subscription_port.go b/backend/internal/service/user_subscription_port.go index 2dfc8d02..4484fae8 100644 --- a/backend/internal/service/user_subscription_port.go +++ b/backend/internal/service/user_subscription_port.go @@ -18,7 +18,7 @@ type UserSubscriptionRepository interface { ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) - List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, platform, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error diff --git a/backend/internal/service/wire.go 
b/backend/internal/service/wire.go index 7da72630..a4c667be 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -486,4 +486,5 @@ var ProviderSet = wire.NewSet( ProvideIdempotencyCleanupService, ProvideScheduledTestService, ProvideScheduledTestRunnerService, + NewGroupCapacityService, ) diff --git a/backend/internal/setup/handler.go b/backend/internal/setup/handler.go index 1531c97b..c2944ced 100644 --- a/backend/internal/setup/handler.go +++ b/backend/internal/setup/handler.go @@ -247,6 +247,12 @@ func install(c *gin.Context) { return } + req.Admin.Email = strings.TrimSpace(req.Admin.Email) + req.Database.Host = strings.TrimSpace(req.Database.Host) + req.Database.User = strings.TrimSpace(req.Database.User) + req.Database.DBName = strings.TrimSpace(req.Database.DBName) + req.Redis.Host = strings.TrimSpace(req.Redis.Host) + // ========== COMPREHENSIVE INPUT VALIDATION ========== // Database validation if !validateHostname(req.Database.Host) { @@ -319,13 +325,6 @@ func install(c *gin.Context) { return } - // Trim whitespace from string inputs - req.Admin.Email = strings.TrimSpace(req.Admin.Email) - req.Database.Host = strings.TrimSpace(req.Database.Host) - req.Database.User = strings.TrimSpace(req.Database.User) - req.Database.DBName = strings.TrimSpace(req.Database.DBName) - req.Redis.Host = strings.TrimSpace(req.Redis.Host) - cfg := &SetupConfig{ Database: req.Database, Redis: req.Redis, diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go index 41ce4d48..ffca98a5 100644 --- a/backend/internal/web/embed_on.go +++ b/backend/internal/web/embed_on.go @@ -180,7 +180,37 @@ func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte { // Inject before headClose := []byte("") - return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1) + result := bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1) + + // Replace with custom site name so the browser tab shows it 
immediately + result = injectSiteTitle(result, settingsJSON) + + return result +} + +// injectSiteTitle replaces the static <title> in HTML with the configured site name. +// This ensures the browser tab shows the correct title before JS executes. +func injectSiteTitle(html, settingsJSON []byte) []byte { + var cfg struct { + SiteName string `json:"site_name"` + } + if err := json.Unmarshal(settingsJSON, &cfg); err != nil || cfg.SiteName == "" { + return html + } + + // Find and replace the existing <title>... + titleStart := bytes.Index(html, []byte("")) + titleEnd := bytes.Index(html, []byte("")) + if titleStart == -1 || titleEnd == -1 || titleEnd <= titleStart { + return html + } + + newTitle := []byte("" + cfg.SiteName + " - AI API Gateway") + var buf bytes.Buffer + buf.Write(html[:titleStart]) + buf.Write(newTitle) + buf.Write(html[titleEnd+len(""):]) + return buf.Bytes() } // replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go index f270b624..fd47c4da 100644 --- a/backend/internal/web/embed_test.go +++ b/backend/internal/web/embed_test.go @@ -20,6 +20,78 @@ func init() { gin.SetMode(gin.TestMode) } +func TestInjectSiteTitle(t *testing.T) { + t.Run("replaces_title_with_site_name", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{"site_name":"MyCustomSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Contains(t, string(result), "MyCustomSite - AI API Gateway") + assert.NotContains(t, string(result), "Sub2API") + }) + + t.Run("returns_unchanged_when_site_name_empty", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{"site_name":""}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_site_name_missing", func(t *testing.T) { + html := []byte(`Sub2API - AI API 
Gateway`) + settingsJSON := []byte(`{"other_field":"value"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_invalid_json", func(t *testing.T) { + html := []byte(`Sub2API - AI API Gateway`) + settingsJSON := []byte(`{invalid json}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_no_title_tag", func(t *testing.T) { + html := []byte(``) + settingsJSON := []byte(`{"site_name":"MyCustomSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Equal(t, string(html), string(result)) + }) + + t.Run("returns_unchanged_when_title_has_attributes", func(t *testing.T) { + // The function looks for "" literally, so attributes are not supported + // This is acceptable since index.html uses plain <title> without attributes + html := []byte(`<html><head><title lang="en">Sub2API`) + settingsJSON := []byte(`{"site_name":"NewSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + // Should return unchanged since with attributes is not matched + assert.Equal(t, string(html), string(result)) + }) + + t.Run("preserves_rest_of_html", func(t *testing.T) { + html := []byte(`<html><head><meta charset="UTF-8"><title>Sub2API
`) + settingsJSON := []byte(`{"site_name":"TestSite"}`) + + result := injectSiteTitle(html, settingsJSON) + + assert.Contains(t, string(result), ``) + assert.Contains(t, string(result), ``) + assert.Contains(t, string(result), `
`) + assert.Contains(t, string(result), "TestSite - AI API Gateway") + }) +} + func TestReplaceNoncePlaceholder(t *testing.T) { t.Run("replaces_single_placeholder", func(t *testing.T) { html := []byte(``) diff --git a/backend/migrations/075_add_usage_log_upstream_model.sql b/backend/migrations/075_add_usage_log_upstream_model.sql new file mode 100644 index 00000000..7f9f8ec6 --- /dev/null +++ b/backend/migrations/075_add_usage_log_upstream_model.sql @@ -0,0 +1,4 @@ +-- Add upstream_model field to usage_logs. +-- Stores the actual upstream model name when it differs from the requested model +-- (i.e., when model mapping is applied). NULL means no mapping was applied. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS upstream_model VARCHAR(100); diff --git a/backend/migrations/075_map_haiku45_to_sonnet46.sql b/backend/migrations/075_map_haiku45_to_sonnet46.sql new file mode 100644 index 00000000..bbaa45e6 --- /dev/null +++ b/backend/migrations/075_map_haiku45_to_sonnet46.sql @@ -0,0 +1,17 @@ +-- Map claude-haiku-4-5 variants target from claude-sonnet-4-5 to claude-sonnet-4-6 +-- +-- Only updates when the current target is exactly claude-sonnet-4-5. + +-- 1. claude-haiku-4-5 +UPDATE accounts +SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5}', '"claude-sonnet-4-6"') +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping'->>'claude-haiku-4-5' = 'claude-sonnet-4-5'; + +-- 2. 
claude-haiku-4-5-20251001 +UPDATE accounts +SET credentials = jsonb_set(credentials, '{model_mapping,claude-haiku-4-5-20251001}', '"claude-sonnet-4-6"') +WHERE platform = 'antigravity' + AND deleted_at IS NULL + AND credentials->'model_mapping'->>'claude-haiku-4-5-20251001' = 'claude-sonnet-4-5'; diff --git a/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql b/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql new file mode 100644 index 00000000..9eee61be --- /dev/null +++ b/backend/migrations/076_add_usage_log_upstream_model_index_notx.sql @@ -0,0 +1,3 @@ +-- Support upstream_model / mapping model distribution aggregations with time-range filters. +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_created_model_upstream_model +ON usage_logs (created_at, model, upstream_model); diff --git a/backend/migrations/README.md b/backend/migrations/README.md index 47f6fa35..40455ad9 100644 --- a/backend/migrations/README.md +++ b/backend/migrations/README.md @@ -34,18 +34,18 @@ Example: `017_add_gemini_tier_id.sql` ## Migration File Structure -```sql --- +goose Up --- +goose StatementBegin --- Your forward migration SQL here --- +goose StatementEnd +This project uses a custom migration runner (`internal/repository/migrations_runner.go`) that executes the full SQL file content as-is. --- +goose Down --- +goose StatementBegin --- Your rollback migration SQL here --- +goose StatementEnd +- Regular migrations (`*.sql`): executed in a transaction. +- Non-transactional migrations (`*_notx.sql`): split by statement and executed without transaction (for `CONCURRENTLY`). + +```sql +-- Forward-only migration (recommended) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS example_column VARCHAR(100); ``` +> ⚠️ Do **not** place executable "Down" SQL in the same file. The runner does not parse goose Up/Down sections and will execute all SQL statements in the file. + ## Important Rules ### ⚠️ Immutability Principle @@ -66,9 +66,9 @@ Why? 
touch migrations/018_your_change.sql ``` -2. **Write Up and Down migrations** - - Up: Apply the change - - Down: Revert the change (should be symmetric with Up) +2. **Write forward-only migration SQL** + - Put only the intended schema change in the file + - If rollback is needed, create a new migration file to revert 3. **Test locally** ```bash @@ -144,8 +144,6 @@ touch migrations/018_your_new_change.sql ## Example Migration ```sql --- +goose Up --- +goose StatementBegin -- Add tier_id field to Gemini OAuth accounts for quota tracking UPDATE accounts SET credentials = jsonb_set( @@ -157,17 +155,6 @@ SET credentials = jsonb_set( WHERE platform = 'gemini' AND type = 'oauth' AND credentials->>'tier_id' IS NULL; --- +goose StatementEnd - --- +goose Down --- +goose StatementBegin --- Remove tier_id field -UPDATE accounts -SET credentials = credentials - 'tier_id' -WHERE platform = 'gemini' - AND type = 'oauth' - AND credentials->>'tier_id' = 'LEGACY'; --- +goose StatementEnd ``` ## Troubleshooting @@ -194,5 +181,4 @@ VALUES ('NNN_migration.sql', 'calculated_checksum', NOW()); ## References - Migration runner: `internal/repository/migrations_runner.go` -- Goose syntax: https://github.com/pressly/goose - PostgreSQL docs: https://www.postgresql.org/docs/ diff --git a/deploy/docker-compose.local.yml b/deploy/docker-compose.local.yml index d404ac0b..5aea78fb 100644 --- a/deploy/docker-compose.local.yml +++ b/deploy/docker-compose.local.yml @@ -38,7 +38,7 @@ services: - ./data:/app/data # Optional: Mount custom config.yaml (uncomment and create the file first) # Copy config.example.yaml to config.yaml, modify it, then uncomment: - # - ./config.yaml:/app/data/config.yaml:ro + # - ./config.yaml:/app/data/config.yaml environment: # ======================================================================= # Auto Setup (REQUIRED for Docker deployment) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index acd21fd9..a0bc1a60 100644 --- a/deploy/docker-compose.yml 
+++ b/deploy/docker-compose.yml @@ -30,7 +30,7 @@ services: - sub2api_data:/app/data # Optional: Mount custom config.yaml (uncomment and create the file first) # Copy config.example.yaml to config.yaml, modify it, then uncomment: - # - ./config.yaml:/app/data/config.yaml:ro + # - ./config.yaml:/app/data/config.yaml environment: # ======================================================================= # Auto Setup (REQUIRED for Docker deployment) diff --git a/deploy/docker-entrypoint.sh b/deploy/docker-entrypoint.sh index 344429bd..47ab6bf1 100644 --- a/deploy/docker-entrypoint.sh +++ b/deploy/docker-entrypoint.sh @@ -6,7 +6,8 @@ set -e # preventing the non-root sub2api user from writing files. if [ "$(id -u)" = "0" ]; then mkdir -p /app/data - chown -R sub2api:sub2api /app/data + # Use || true to avoid failure on read-only mounted files (e.g. config.yaml:ro) + chown -R sub2api:sub2api /app/data 2>/dev/null || true # Re-invoke this script as sub2api so the flag-detection below # also runs under the correct user. 
exec su-exec sub2api "$0" "$@" diff --git a/frontend/src/App.vue b/frontend/src/App.vue index 4fc6a7c8..7485aa1a 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -3,6 +3,7 @@ import { RouterView, useRouter, useRoute } from 'vue-router' import { onMounted, onBeforeUnmount, watch } from 'vue' import Toast from '@/components/common/Toast.vue' import NavigationProgress from '@/components/common/NavigationProgress.vue' +import { resolveDocumentTitle } from '@/router/title' import AnnouncementPopup from '@/components/common/AnnouncementPopup.vue' import { useAppStore, useAuthStore, useSubscriptionStore, useAnnouncementStore } from '@/stores' import { getSetupStatus } from '@/api/setup' @@ -104,6 +105,9 @@ onMounted(async () => { // Load public settings into appStore (will be cached for other components) await appStore.fetchPublicSettings() + + // Re-resolve document title now that siteName is available + document.title = resolveDocumentTitle(route.meta.title, appStore.siteName, route.meta.titleKey as string) }) diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 0bf0a2c5..15d1540f 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -81,6 +81,7 @@ export interface ModelStatsParams { user_id?: number api_key_id?: number model?: string + model_source?: 'requested' | 'upstream' | 'mapping' account_id?: number group_id?: number request_type?: UsageRequestType @@ -162,6 +163,7 @@ export interface UserBreakdownParams { end_date?: string group_id?: number model?: string + model_source?: 'requested' | 'upstream' | 'mapping' endpoint?: string endpoint_type?: 'inbound' | 'upstream' | 'path' limit?: number diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 7c2658fa..5885dc6a 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -218,6 +218,34 @@ export async function batchSetGroupRateMultipliers( return data } +/** + 
 * Get usage summary (today + cumulative cost) for all groups + * @param timezone - IANA timezone string (e.g. "Asia/Shanghai") + * @returns Array of group usage summaries + */ +export async function getUsageSummary( + timezone?: string +): Promise<{ group_id: number; today_cost: number; total_cost: number }[]> { + const { data } = await apiClient.get< + { group_id: number; today_cost: number; total_cost: number }[] + >('/admin/groups/usage-summary', { + params: timezone ? { timezone } : undefined + }) + return data +} + +/** + * Get capacity summary (concurrency/sessions/RPM) for all active groups + */ +export async function getCapacitySummary(): Promise< + { group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[] +> { + const { data } = await apiClient.get< + { group_id: number; concurrency_used: number; concurrency_max: number; sessions_used: number; sessions_max: number; rpm_used: number; rpm_max: number }[] + >('/admin/groups/capacity-summary') + return data +} + export const groupsAPI = { list, getAll, @@ -232,7 +260,9 @@ getGroupRateMultipliers, clearGroupRateMultipliers, batchSetGroupRateMultipliers, - updateSortOrder + updateSortOrder, + getUsageSummary, + getCapacitySummary } export default groupsAPI diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 040cf71e..a2cd67f0 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -242,6 +242,33 @@ export async function deleteAdminApiKey(): Promise<{ message: string }> { return data } +// ==================== Overload Cooldown Settings ==================== + +/** + * Overload cooldown settings interface (529 handling) + */ +export interface OverloadCooldownSettings { + enabled: boolean + cooldown_minutes: number +} + +export async function getOverloadCooldownSettings(): Promise<OverloadCooldownSettings> { + const { data } = await
apiClient.get<OverloadCooldownSettings>('/admin/settings/overload-cooldown') + return data +} + +export async function updateOverloadCooldownSettings( + settings: OverloadCooldownSettings +): Promise<OverloadCooldownSettings> { + const { data } = await apiClient.put<OverloadCooldownSettings>( + '/admin/settings/overload-cooldown', + settings + ) + return data +} + +// ==================== Stream Timeout Settings ==================== + /** * Stream timeout settings interface */ @@ -499,6 +526,8 @@ export const settingsAPI = { getAdminApiKey, regenerateAdminApiKey, deleteAdminApiKey, + getOverloadCooldownSettings, + updateOverloadCooldownSettings, getStreamTimeoutSettings, updateStreamTimeoutSettings, getRectifierSettings, diff --git a/frontend/src/api/admin/subscriptions.ts b/frontend/src/api/admin/subscriptions.ts index 7557e3ad..611f67c2 100644 --- a/frontend/src/api/admin/subscriptions.ts +++ b/frontend/src/api/admin/subscriptions.ts @@ -27,6 +27,7 @@ export async function list( status?: 'active' | 'expired' | 'revoked' user_id?: number group_id?: number + platform?: string sort_by?: string sort_order?: 'asc' | 'desc' }, diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index e548be8c..131d82b2 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -82,6 +82,7 @@ :utilization="usageInfo.five_hour.utilization" :resets-at="usageInfo.five_hour.resets_at" :window-stats="usageInfo.five_hour.window_stats" + :show-now-when-idle="true" color="indigo" /> diff --git a/frontend/src/components/account/UsageProgressBar.vue b/frontend/src/components/account/UsageProgressBar.vue index 506071fa..52f0ecbb 100644 --- a/frontend/src/components/account/UsageProgressBar.vue +++ b/frontend/src/components/account/UsageProgressBar.vue @@ -48,7 +48,7 @@ - + {{ formatResetTime }} @@ -68,6 +68,7 @@ const props = defineProps<{ resetsAt?: string | null color: 'indigo' | 'emerald' | 'purple' | 'amber' windowStats?: WindowStats
| null + showNowWhenIdle?: boolean }>() const { t } = useI18n() @@ -139,9 +140,20 @@ const displayPercent = computed(() => { return percent > 999 ? '>999%' : `${percent}%` }) +const shouldShowResetTime = computed(() => { + if (props.resetsAt) return true + return Boolean(props.showNowWhenIdle && props.utilization <= 0) +}) + // Format reset time const formatResetTime = computed(() => { + // For rolling windows, when utilization is 0%, treat as immediately available. + if (props.showNowWhenIdle && props.utilization <= 0) { + return '现在' + } + if (!props.resetsAt) return '-' + const date = new Date(props.resetsAt) const diffMs = date.getTime() - now.value.getTime() diff --git a/frontend/src/components/account/__tests__/UsageProgressBar.spec.ts b/frontend/src/components/account/__tests__/UsageProgressBar.spec.ts new file mode 100644 index 00000000..9def052c --- /dev/null +++ b/frontend/src/components/account/__tests__/UsageProgressBar.spec.ts @@ -0,0 +1,69 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import UsageProgressBar from '../UsageProgressBar.vue' + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +describe('UsageProgressBar', () => { + beforeEach(() => { + vi.useFakeTimers() + vi.setSystemTime(new Date('2026-03-17T00:00:00Z')) + }) + + afterEach(() => { + vi.useRealTimers() + }) + + it('showNowWhenIdle=true 且利用率为 0 时显示“现在”', () => { + const wrapper = mount(UsageProgressBar, { + props: { + label: '5h', + utilization: 0, + resetsAt: '2026-03-17T02:30:00Z', + showNowWhenIdle: true, + color: 'indigo' + } + }) + + expect(wrapper.text()).toContain('现在') + expect(wrapper.text()).not.toContain('2h 30m') + }) + + it('showNowWhenIdle=true 但利用率大于 0 时显示倒计时', () => { + const wrapper = mount(UsageProgressBar, { + props: { + label: '7d', + utilization: 12, + resetsAt: 
'2026-03-17T02:30:00Z', + showNowWhenIdle: true, + color: 'emerald' + } + }) + + expect(wrapper.text()).toContain('2h 30m') + expect(wrapper.text()).not.toContain('现在') + }) + + it('showNowWhenIdle=false 时保持原有倒计时行为', () => { + const wrapper = mount(UsageProgressBar, { + props: { + label: '1d', + utilization: 0, + resetsAt: '2026-03-17T02:30:00Z', + showNowWhenIdle: false, + color: 'indigo' + } + }) + + expect(wrapper.text()).toContain('2h 30m') + expect(wrapper.text()).not.toContain('现在') + }) +}) diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index aa6c2bbd..4a42ab05 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -25,8 +25,16 @@ {{ row.account?.name || '-' }} -