diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 6c460d8e..34043115 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -26,37 +26,37 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` - RestrictModels bool `json:"restrict_models"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` + RestrictModels bool `json:"restrict_models"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` - RestrictModels *bool `json:"restrict_models"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream"` + RestrictModels *bool `json:"restrict_models"` } type channelModelPricingRequest struct { - Platform string `json:"platform" binding:"omitempty,max=50"` - Models []string `json:"models" binding:"required,min=1,max=100"` - BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` - InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` - OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` - CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` - CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` - ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` - PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"` - Intervals []pricingIntervalRequest `json:"intervals"` + Platform string `json:"platform" binding:"omitempty,max=50"` + Models []string `json:"models" binding:"required,min=1,max=100"` + BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"` + InputPrice *float64 `json:"input_price" binding:"omitempty,min=0"` + OutputPrice *float64 `json:"output_price" binding:"omitempty,min=0"` + CacheWritePrice *float64 `json:"cache_write_price" binding:"omitempty,min=0"` + CacheReadPrice *float64 `json:"cache_read_price" binding:"omitempty,min=0"` + ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"` + PerRequestPrice *float64 `json:"per_request_price" binding:"omitempty,min=0"` + Intervals []pricingIntervalRequest `json:"intervals"` } type pricingIntervalRequest struct { @@ -72,31 +72,31 @@ type pricingIntervalRequest struct { } type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - BillingModelSource string `json:"billing_model_source"` - RestrictModels bool `json:"restrict_models"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { - ID int64 `json:"id"` - Platform string `json:"platform"` - Models []string `json:"models"` - BillingMode string `json:"billing_mode"` - InputPrice *float64 `json:"input_price"` - OutputPrice *float64 `json:"output_price"` - CacheWritePrice *float64 `json:"cache_write_price"` - CacheReadPrice *float64 `json:"cache_read_price"` - ImageOutputPrice *float64 `json:"image_output_price"` - PerRequestPrice *float64 `json:"per_request_price"` - Intervals []pricingIntervalResponse `json:"intervals"` + ID int64 `json:"id"` + Platform string `json:"platform"` + Models []string `json:"models"` + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []pricingIntervalResponse `json:"intervals"` } type pricingIntervalResponse struct { @@ -117,15 +117,15 @@ func channelToResponse(ch *service.Channel) *channelResponse { return nil } resp := &channelResponse{ - ID: ch.ID, - Name: ch.Name, - Description: ch.Description, - Status: ch.Status, + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, RestrictModels: ch.RestrictModels, - GroupIDs: ch.GroupIDs, - ModelMapping: ch.ModelMapping, - CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), - UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), + GroupIDs: ch.GroupIDs, + ModelMapping: ch.ModelMapping, + CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), + UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"), } resp.BillingModelSource = ch.BillingModelSource if resp.BillingModelSource == "" { @@ -298,9 +298,9 @@ func (h *ChannelHandler) Create(c *gin.Context) { channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ Name: req.Name, Description: req.Description, - GroupIDs: req.GroupIDs, - ModelPricing: pricing, - ModelMapping: req.ModelMapping, + GroupIDs: req.GroupIDs, + ModelPricing: pricing, + ModelMapping: req.ModelMapping, BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, }) @@ -331,8 +331,8 @@ func (h *ChannelHandler) Update(c *gin.Context) { Name: req.Name, Description: req.Description, Status: req.Status, - GroupIDs: req.GroupIDs, - ModelMapping: req.ModelMapping, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, } diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go new file mode 100644 index 00000000..a4c3234e --- /dev/null +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -0,0 +1,502 @@ +//go:build unit + +package admin + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func float64Ptr(v float64) *float64 { return &v } +func intPtr(v int) *int { return &v } + +// --------------------------------------------------------------------------- +// 1. channelToResponse +// --------------------------------------------------------------------------- + +func TestChannelToResponse_NilInput(t *testing.T) { + require.Nil(t, channelToResponse(nil)) +} + +func TestChannelToResponse_FullChannel(t *testing.T) { + now := time.Date(2025, 6, 1, 12, 0, 0, 0, time.UTC) + ch := &service.Channel{ + ID: 42, + Name: "test-channel", + Description: "desc", + Status: "active", + BillingModelSource: "upstream", + RestrictModels: true, + CreatedAt: now, + UpdatedAt: now.Add(time.Hour), + GroupIDs: []int64{1, 2, 3}, + ModelPricing: []service.ChannelModelPricing{ + { + ID: 10, + Platform: "openai", + Models: []string{"gpt-4"}, + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.03), + CacheWritePrice: float64Ptr(0.005), + CacheReadPrice: float64Ptr(0.002), + PerRequestPrice: float64Ptr(0.5), + }, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-3-haiku": "claude-haiku-3"}, + }, + } + + resp := channelToResponse(ch) + require.NotNil(t, resp) + require.Equal(t, int64(42), resp.ID) + require.Equal(t, "test-channel", resp.Name) + require.Equal(t, "desc", resp.Description) + require.Equal(t, "active", resp.Status) + require.Equal(t, "upstream", resp.BillingModelSource) + require.True(t, resp.RestrictModels) + require.Equal(t, []int64{1, 2, 3}, resp.GroupIDs) + require.Equal(t, "2025-06-01T12:00:00Z", resp.CreatedAt) + require.Equal(t, "2025-06-01T13:00:00Z", resp.UpdatedAt) + + // model mapping + require.Len(t, resp.ModelMapping, 1) + require.Equal(t, "claude-haiku-3", resp.ModelMapping["anthropic"]["claude-3-haiku"]) + + // pricing + require.Len(t, resp.ModelPricing, 1) + p := resp.ModelPricing[0] + require.Equal(t, int64(10), p.ID) + require.Equal(t, "openai", p.Platform) + require.Equal(t, []string{"gpt-4"}, p.Models) + require.Equal(t, "token", p.BillingMode) + require.Equal(t, float64Ptr(0.01), p.InputPrice) + require.Equal(t, float64Ptr(0.03), p.OutputPrice) + require.Equal(t, float64Ptr(0.005), p.CacheWritePrice) + require.Equal(t, float64Ptr(0.002), p.CacheReadPrice) + require.Equal(t, float64Ptr(0.5), p.PerRequestPrice) + require.Empty(t, p.Intervals) +} + +func TestChannelToResponse_EmptyDefaults(t *testing.T) { + now := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + ch := &service.Channel{ + ID: 1, + Name: "ch", + BillingModelSource: "", + CreatedAt: now, + UpdatedAt: now, + GroupIDs: nil, + ModelMapping: nil, + ModelPricing: []service.ChannelModelPricing{ + { + Platform: "", + BillingMode: "", + Models: []string{"m1"}, + }, + }, + } + + resp := channelToResponse(ch) + require.Equal(t, "requested", resp.BillingModelSource) + require.NotNil(t, resp.GroupIDs) + require.Empty(t, resp.GroupIDs) + require.NotNil(t, resp.ModelMapping) + require.Empty(t, resp.ModelMapping) + + require.Len(t, resp.ModelPricing, 1) + require.Equal(t, "anthropic", resp.ModelPricing[0].Platform) + require.Equal(t, "token", resp.ModelPricing[0].BillingMode) +} + +func TestChannelToResponse_NilModels(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "ch", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + Models: nil, + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 1) + require.NotNil(t, resp.ModelPricing[0].Models) + require.Empty(t, resp.ModelPricing[0].Models) +} + +func TestChannelToResponse_WithIntervals(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "ch", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + Models: []string{"m1"}, + BillingMode: service.BillingModePerRequest, + Intervals: []service.PricingInterval{ + { + ID: 100, + MinTokens: 0, + MaxTokens: intPtr(1000), + TierLabel: "1K", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.02), + CacheWritePrice: float64Ptr(0.003), + CacheReadPrice: float64Ptr(0.001), + PerRequestPrice: float64Ptr(0.1), + SortOrder: 1, + }, + { + ID: 101, + MinTokens: 1000, + MaxTokens: nil, + TierLabel: "unlimited", + SortOrder: 2, + }, + }, + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 1) + intervals := resp.ModelPricing[0].Intervals + require.Len(t, intervals, 2) + + iv0 := intervals[0] + require.Equal(t, int64(100), iv0.ID) + require.Equal(t, 0, iv0.MinTokens) + require.Equal(t, intPtr(1000), iv0.MaxTokens) + require.Equal(t, "1K", iv0.TierLabel) + require.Equal(t, float64Ptr(0.01), iv0.InputPrice) + require.Equal(t, float64Ptr(0.02), iv0.OutputPrice) + require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice) + require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice) + require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice) + require.Equal(t, 1, iv0.SortOrder) + + iv1 := intervals[1] + require.Equal(t, int64(101), iv1.ID) + require.Equal(t, 1000, iv1.MinTokens) + require.Nil(t, iv1.MaxTokens) + require.Equal(t, "unlimited", iv1.TierLabel) + require.Equal(t, 2, iv1.SortOrder) +} + +func TestChannelToResponse_MultipleEntries(t *testing.T) { + now := time.Now() + ch := &service.Channel{ + ID: 1, + Name: "multi", + CreatedAt: now, + UpdatedAt: now, + ModelPricing: []service.ChannelModelPricing{ + { + ID: 1, + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.003), + OutputPrice: float64Ptr(0.015), + }, + { + ID: 2, + Platform: "openai", + Models: []string{"gpt-4", "gpt-4o"}, + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(1.0), + }, + { + ID: 3, + Platform: "gemini", + Models: []string{"gemini-2.5-pro"}, + BillingMode: service.BillingModeImage, + ImageOutputPrice: float64Ptr(0.05), + PerRequestPrice: float64Ptr(0.2), + }, + }, + } + + resp := channelToResponse(ch) + require.Len(t, resp.ModelPricing, 3) + + require.Equal(t, int64(1), resp.ModelPricing[0].ID) + require.Equal(t, "anthropic", resp.ModelPricing[0].Platform) + require.Equal(t, []string{"claude-sonnet-4"}, resp.ModelPricing[0].Models) + require.Equal(t, "token", resp.ModelPricing[0].BillingMode) + + require.Equal(t, int64(2), resp.ModelPricing[1].ID) + require.Equal(t, "openai", resp.ModelPricing[1].Platform) + require.Equal(t, []string{"gpt-4", "gpt-4o"}, resp.ModelPricing[1].Models) + require.Equal(t, "per_request", resp.ModelPricing[1].BillingMode) + + require.Equal(t, int64(3), resp.ModelPricing[2].ID) + require.Equal(t, "gemini", resp.ModelPricing[2].Platform) + require.Equal(t, []string{"gemini-2.5-pro"}, resp.ModelPricing[2].Models) + require.Equal(t, "image", resp.ModelPricing[2].BillingMode) + require.Equal(t, float64Ptr(0.05), resp.ModelPricing[2].ImageOutputPrice) +} + +// --------------------------------------------------------------------------- +// 2. pricingRequestToService +// --------------------------------------------------------------------------- + +func TestPricingRequestToService_Defaults(t *testing.T) { + tests := []struct { + name string + req channelModelPricingRequest + wantField string // which default field to check + wantValue string + }{ + { + name: "empty billing mode defaults to token", + req: channelModelPricingRequest{ + Models: []string{"m1"}, + BillingMode: "", + }, + wantField: "BillingMode", + wantValue: string(service.BillingModeToken), + }, + { + name: "empty platform defaults to anthropic", + req: channelModelPricingRequest{ + Models: []string{"m1"}, + Platform: "", + }, + wantField: "Platform", + wantValue: "anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pricingRequestToService([]channelModelPricingRequest{tt.req}) + require.Len(t, result, 1) + switch tt.wantField { + case "BillingMode": + require.Equal(t, service.BillingMode(tt.wantValue), result[0].BillingMode) + case "Platform": + require.Equal(t, tt.wantValue, result[0].Platform) + } + }) + } +} + +func TestPricingRequestToService_WithAllFields(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Platform: "openai", + Models: []string{"gpt-4", "gpt-4o"}, + BillingMode: "per_request", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.03), + CacheWritePrice: float64Ptr(0.005), + CacheReadPrice: float64Ptr(0.002), + ImageOutputPrice: float64Ptr(0.04), + PerRequestPrice: float64Ptr(0.5), + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + r := result[0] + require.Equal(t, "openai", r.Platform) + require.Equal(t, []string{"gpt-4", "gpt-4o"}, r.Models) + require.Equal(t, service.BillingModePerRequest, r.BillingMode) + require.Equal(t, float64Ptr(0.01), r.InputPrice) + require.Equal(t, float64Ptr(0.03), r.OutputPrice) + require.Equal(t, float64Ptr(0.005), r.CacheWritePrice) + require.Equal(t, float64Ptr(0.002), r.CacheReadPrice) + require.Equal(t, float64Ptr(0.04), r.ImageOutputPrice) + require.Equal(t, float64Ptr(0.5), r.PerRequestPrice) +} + +func TestPricingRequestToService_WithIntervals(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Models: []string{"m1"}, + BillingMode: "per_request", + Intervals: []pricingIntervalRequest{ + { + MinTokens: 0, + MaxTokens: intPtr(2000), + TierLabel: "small", + InputPrice: float64Ptr(0.01), + OutputPrice: float64Ptr(0.02), + CacheWritePrice: float64Ptr(0.003), + CacheReadPrice: float64Ptr(0.001), + PerRequestPrice: float64Ptr(0.1), + SortOrder: 1, + }, + { + MinTokens: 2000, + MaxTokens: nil, + TierLabel: "large", + SortOrder: 2, + }, + }, + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + require.Len(t, result[0].Intervals, 2) + + iv0 := result[0].Intervals[0] + require.Equal(t, 0, iv0.MinTokens) + require.Equal(t, intPtr(2000), iv0.MaxTokens) + require.Equal(t, "small", iv0.TierLabel) + require.Equal(t, float64Ptr(0.01), iv0.InputPrice) + require.Equal(t, float64Ptr(0.02), iv0.OutputPrice) + require.Equal(t, float64Ptr(0.003), iv0.CacheWritePrice) + require.Equal(t, float64Ptr(0.001), iv0.CacheReadPrice) + require.Equal(t, float64Ptr(0.1), iv0.PerRequestPrice) + require.Equal(t, 1, iv0.SortOrder) + + iv1 := result[0].Intervals[1] + require.Equal(t, 2000, iv1.MinTokens) + require.Nil(t, iv1.MaxTokens) + require.Equal(t, "large", iv1.TierLabel) + require.Equal(t, 2, iv1.SortOrder) +} + +func TestPricingRequestToService_EmptySlice(t *testing.T) { + result := pricingRequestToService([]channelModelPricingRequest{}) + require.NotNil(t, result) + require.Empty(t, result) +} + +func TestPricingRequestToService_NilPriceFields(t *testing.T) { + reqs := []channelModelPricingRequest{ + { + Models: []string{"m1"}, + BillingMode: "token", + // all price fields are nil by default + }, + } + + result := pricingRequestToService(reqs) + require.Len(t, result, 1) + r := result[0] + require.Nil(t, r.InputPrice) + require.Nil(t, r.OutputPrice) + require.Nil(t, r.CacheWritePrice) + require.Nil(t, r.CacheReadPrice) + require.Nil(t, r.ImageOutputPrice) + require.Nil(t, r.PerRequestPrice) +} + +// --------------------------------------------------------------------------- +// 3. validatePricingBillingMode +// --------------------------------------------------------------------------- + +func TestValidatePricingBillingMode(t *testing.T) { + tests := []struct { + name string + pricing []service.ChannelModelPricing + wantErr bool + }{ + { + name: "token mode - valid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModeToken}, + }, + wantErr: false, + }, + { + name: "per_request with price - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(0.5), + }, + }, + wantErr: false, + }, + { + name: "per_request with intervals - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModePerRequest, + Intervals: []service.PricingInterval{ + {MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)}, + }, + }, + }, + wantErr: false, + }, + { + name: "per_request no price no intervals - invalid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModePerRequest}, + }, + wantErr: true, + }, + { + name: "image with price - valid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModeImage, + PerRequestPrice: float64Ptr(0.2), + }, + }, + wantErr: false, + }, + { + name: "image no price no intervals - invalid", + pricing: []service.ChannelModelPricing{ + {BillingMode: service.BillingModeImage}, + }, + wantErr: true, + }, + { + name: "empty list - valid", + pricing: []service.ChannelModelPricing{}, + wantErr: false, + }, + { + name: "mixed modes with invalid image - invalid", + pricing: []service.ChannelModelPricing{ + { + BillingMode: service.BillingModeToken, + InputPrice: float64Ptr(0.01), + }, + { + BillingMode: service.BillingModePerRequest, + PerRequestPrice: float64Ptr(0.5), + }, + { + BillingMode: service.BillingModeImage, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePricingBillingMode(tt.pricing) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "Per-request price or intervals required") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 2a214471..460f6357 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -636,6 +636,40 @@ func (h *DashboardHandler) GetUserBreakdown(c *gin.Context) { dim.Endpoint = c.Query("endpoint") dim.EndpointType = c.DefaultQuery("endpoint_type", "inbound") + // Additional filter conditions + if v := c.Query("user_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.UserID = id + } + } + if v := c.Query("api_key_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.APIKeyID = id + } + } + if v := c.Query("account_id"); v != "" { + if id, err := strconv.ParseInt(v, 10, 64); err == nil { + dim.AccountID = id + } + } + if v := c.Query("request_type"); v != "" { + if rt, err := strconv.ParseInt(v, 10, 16); err == nil { + rtVal := int16(rt) + dim.RequestType = &rtVal + } + } + if v := c.Query("stream"); v != "" { + if s, err := strconv.ParseBool(v); err == nil { + dim.Stream = &s + } + } + if v := c.Query("billing_type"); v != "" { + if bt, err := strconv.ParseInt(v, 10, 8); err == nil { + btVal := int8(bt) + dim.BillingType = &btVal + } + } + limit := 50 if v := c.Query("limit"); v != "" { if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 200 { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8f66ad03..651936c1 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -485,10 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), @@ -828,10 +825,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { RequestPayloadHash: requestPayloadHash, ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gateway.messages"), diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index f0f16131..b70582f6 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -266,10 +266,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.cc.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index 1e9cdc02..d4ee905a 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -272,10 +272,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("gateway.responses.record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 7c1386b8..55556764 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -534,10 +534,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { LongContextMultiplier: 2.0, // 超出部分双倍计费 ForceCacheBilling: fs.ForceCacheBilling, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.gemini_v1beta.models"), diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index a117c3be..20695e0e 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -278,10 +278,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { UserAgent: userAgent, IPAddress: clientIP, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.chat_completions"), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 70198a53..d1fc9b51 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -391,10 +391,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, - ChannelID: channelMapping.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMapping.BillingModelSource, - ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMapping.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -787,10 +784,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, - ChannelID: channelMappingMsg.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMappingMsg.BillingModelSource, - ModelMappingChain: channelMappingMsg.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMappingMsg.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), @@ -1298,10 +1292,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), APIKeyService: h.apiKeyService, - ChannelID: channelMappingWS.ChannelID, - OriginalModel: reqModel, - BillingModelSource: channelMappingWS.BillingModelSource, - ModelMappingChain: channelMappingWS.BuildModelMappingChain(reqModel, result.UpstreamModel), + ChannelUsageFields: channelMappingWS.ToUsageFields(reqModel, result.UpstreamModel), }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index 57055786..32a6afcc 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -125,6 +125,13 @@ func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []s return r.countValue, nil } +func (r *stubSoraGenRepo) CountByStorageType(_ context.Context, _ string, _ []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + return r.countValue, nil +} + // ==================== 辅助函数 ==================== func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { @@ -1657,8 +1664,8 @@ func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, @@ -1679,8 +1686,8 @@ func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( @@ -1704,8 +1711,8 @@ func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { })) defer badSource.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} _, _, storageType, _, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, @@ -1719,8 +1726,8 @@ func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { fakeS3 := newFakeS3Server("fail") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, @@ -1736,8 +1743,8 @@ func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { fakeS3 := newFakeS3Server("fail-second") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( @@ -1808,7 +1815,7 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { fakeS3 := newFakeS3Server("fail") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) cfg := &config.Config{ Sora: config.SoraConfig{ Storage: config.SoraStorageConfig{ @@ -1821,8 +1828,8 @@ func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { } mediaStorage := service.NewSoraMediaStorage(cfg) h := &SoraClientHandler{ - s3Storage: s3Storage, - mediaStorage: mediaStorage, + objectStorage: objectStorage, + mediaStorage: mediaStorage, } _, _, storageType, _, _ := h.storeMediaWithDegradation( @@ -1846,9 +1853,9 @@ func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { StorageType: "upstream", MediaURL: sourceServer.URL + "/v.mp4", } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -1872,9 +1879,9 @@ func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { StorageType: "upstream", MediaURL: expiredServer.URL + "/v.mp4", } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -1896,9 +1903,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { StorageType: "upstream", MediaURL: sourceServer.URL + "/v.mp4", } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -1906,7 +1913,7 @@ func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { require.Equal(t, http.StatusOK, rec.Code) resp := parseResponse(t, rec) data := resp["data"].(map[string]any) - require.Contains(t, data["message"], "S3") + require.Contains(t, data["message"], "云存储") require.NotEmpty(t, data["object_key"]) // 验证记录已更新为 S3 存储 require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) @@ -1928,9 +1935,9 @@ func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { sourceServer.URL + "/v2.mp4", }, } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -1956,7 +1963,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { StorageType: "upstream", MediaURL: sourceServer.URL + "/v.mp4", } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) userRepo := newStubUserRepoForHandler() @@ -1966,7 +1973,7 @@ func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { SoraStorageUsedBytes: 0, } quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -1990,9 +1997,9 @@ func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { } // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -2007,8 +2014,8 @@ func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { fakeS3 := newFakeS3Server("fail") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) h.GetStorageStatus(c) @@ -2023,8 +2030,8 @@ func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) h.GetStorageStatus(c) @@ -2453,7 +2460,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { }, } soraGatewayService := newMinimalSoraGatewayService(soraClient) - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) userRepo := newStubUserRepoForHandler() userRepo.users[1] = &service.User{ @@ -2465,7 +2472,7 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { genService: genService, gatewayService: gatewayService, soraGatewayService: soraGatewayService, - s3Storage: s3Storage, + objectStorage: objectStorage, quotaService: quotaService, } @@ -2515,7 +2522,7 @@ func TestProcessGeneration_MarkCompletedFails(t *testing.T) { // ==================== cleanupStoredMedia 直接测试 ==================== func TestCleanupStoredMedia_S3Path(t *testing.T) { - // S3 清理路径:s3Storage 为 nil 时不 panic + // S3 清理路径:objectStorage 为 nil 时不 panic h := &SoraClientHandler{} // 不应 panic h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) @@ -2962,7 +2969,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) { StorageType: "upstream", MediaURL: sourceServer.URL + "/v.mp4", } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) // 用户配额已满 @@ -2973,7 +2980,7 @@ func TestSaveToStorage_QuotaExceeded(t *testing.T) { SoraStorageUsedBytes: 10, } quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -2995,13 +3002,13 @@ func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { StorageType: "upstream", MediaURL: sourceServer.URL + "/v.mp4", } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error userRepo := newStubUserRepoForHandler() quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -3022,9 +3029,9 @@ func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { MediaURL: "", MediaURLs: []string{}, } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -3049,9 +3056,9 @@ func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { MediaURL: sourceServer.URL + "/v1.mp4", MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, } - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -3074,7 +3081,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { MediaURL: sourceServer.URL + "/v.mp4", } repo.updateErr = fmt.Errorf("db error") - s3Storage := newS3StorageForHandler(fakeS3.URL) + objectStorage := newS3StorageForHandler(fakeS3.URL) genService := service.NewSoraGenerationService(repo, nil, nil) userRepo := newStubUserRepoForHandler() @@ -3084,7 +3091,7 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { SoraStorageUsedBytes: 0, } quotaService := service.NewSoraQuotaService(userRepo, nil, nil) - h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + h := &SoraClientHandler{genService: genService, objectStorage: objectStorage, quotaService: quotaService} c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) c.Params = gin.Params{{Key: "id", Value: "1"}} @@ -3097,8 +3104,8 @@ func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { fakeS3 := newFakeS3Server("ok") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) } @@ -3106,8 +3113,8 @@ func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { fakeS3 := newFakeS3Server("fail") defer fakeS3.Close() - s3Storage := newS3StorageForHandler(fakeS3.URL) - h := &SoraClientHandler{s3Storage: s3Storage} + objectStorage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{objectStorage: objectStorage} h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) } diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 5e505409..c99a0de9 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -30,6 +30,8 @@ import ( ) // SoraGatewayHandler handles Sora chat completions requests +// +// NOTE: Sora 平台计划后续移除,不集成渠道(Channel)功能。 type SoraGatewayHandler struct { gatewayService *service.GatewayService soraGatewayService *service.SoraGatewayService diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index eb2c81d3..5d1f7911 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -175,6 +175,13 @@ type UserBreakdownDimension struct { ModelType string // "requested", "upstream", or "mapping" Endpoint string // filter by endpoint value (non-empty to enable) EndpointType string // "inbound", "upstream", or "path" + // Additional filter conditions + UserID int64 // filter by user_id (>0 to enable) + APIKeyID int64 // filter by api_key_id (>0 to enable) + AccountID int64 // filter by account_id (>0 to enable) + RequestType *int16 // filter by request_type (non-nil to enable) + Stream *bool // filter by stream flag (non-nil to enable) + BillingType *int8 // filter by billing_type (non-nil to enable) } // APIKeyUsageTrendPoint represents API key usage trend data point diff --git a/backend/internal/repository/channel_repo_pricing.go b/backend/internal/repository/channel_repo_pricing.go index 73887617..6dcf3c91 100644 --- a/backend/internal/repository/channel_repo_pricing.go +++ b/backend/internal/repository/channel_repo_pricing.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" @@ -274,7 +275,8 @@ func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pr // isUniqueViolation 检查 pq 唯一约束违反错误 func isUniqueViolation(err error) bool { - if pqErr, ok := err.(*pq.Error); ok { + var pqErr *pq.Error + if errors.As(err, &pqErr) && pqErr != nil { return pqErr.Code == "23505" } return false diff --git a/backend/internal/repository/channel_repo_test.go b/backend/internal/repository/channel_repo_test.go new file mode 100644 index 00000000..86a09afe --- /dev/null +++ b/backend/internal/repository/channel_repo_test.go @@ -0,0 +1,227 @@ +//go:build unit + +package repository + +import ( + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/lib/pq" + "github.com/stretchr/testify/require" +) + +// --- marshalModelMapping --- + +func TestMarshalModelMapping(t *testing.T) { + tests := []struct { + name string + input map[string]map[string]string + wantJSON string // expected JSON output (exact match) + }{ + { + name: "empty map", + input: map[string]map[string]string{}, + wantJSON: "{}", + }, + { + name: "nil map", + input: nil, + wantJSON: "{}", + }, + { + name: "populated map", + input: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + }, + }, + { + name: "nested values", + input: map[string]map[string]string{ + "openai": {"*": "gpt-5.4"}, + "anthropic": {"claude-old": "claude-new"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := marshalModelMapping(tt.input) + require.NoError(t, err) + + if tt.wantJSON != "" { + require.Equal(t, []byte(tt.wantJSON), result) + } else { + // round-trip: unmarshal and compare with input + var parsed map[string]map[string]string + require.NoError(t, json.Unmarshal(result, &parsed)) + require.Equal(t, tt.input, parsed) + } + }) + } +} + +// --- unmarshalModelMapping --- + +func TestUnmarshalModelMapping(t *testing.T) { + tests := []struct { + name string + input []byte + wantNil bool + want map[string]map[string]string + }{ + { + name: "nil data", + input: nil, + wantNil: true, + }, + { + name: "empty data", + input: []byte{}, + wantNil: true, + }, + { + name: "invalid JSON", + input: []byte("not-json"), + wantNil: true, + }, + { + name: "type error - number", + input: []byte("42"), + wantNil: true, + }, + { + name: "type error - array", + input: []byte("[1,2,3]"), + wantNil: true, + }, + { + name: "valid JSON", + input: []byte(`{"openai":{"gpt-4":"gpt-4-turbo"},"anthropic":{"old":"new"}}`), + want: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + "anthropic": {"old": "new"}, + }, + }, + { + name: "empty object", + input: []byte("{}"), + want: map[string]map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := unmarshalModelMapping(tt.input) + if tt.wantNil { + require.Nil(t, result) + } else { + require.NotNil(t, result) + require.Equal(t, tt.want, result) + } + }) + } +} + +// --- escapeLike --- + +func TestEscapeLike(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no special chars", + input: "hello", + want: "hello", + }, + { + name: "backslash", + input: `a\b`, + want: `a\\b`, + }, + { + name: "percent", + input: "50%", + want: `50\%`, + }, + { + name: "underscore", + input: "a_b", + want: `a\_b`, + }, + { + name: "all special chars", + input: `a\b%c_d`, + want: `a\\b\%c\_d`, + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "consecutive special chars", + input: "%_%", + want: `\%\_\%`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, escapeLike(tt.input)) + }) + } +} + +// --- isUniqueViolation --- + +func TestIsUniqueViolation(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "unique violation code 23505", + err: &pq.Error{Code: "23505"}, + want: true, + }, + { + name: "different pq error code", + err: &pq.Error{Code: "23503"}, + want: false, + }, + { + name: "non-pq error", + err: errors.New("some generic error"), + want: false, + }, + { + name: "typed nil pq.Error", + err: func() error { + var pqErr *pq.Error + return pqErr + }(), + want: false, + }, + { + name: "bare nil", + err: nil, + want: false, + }, + { + name: "wrapped pq error with 23505", + err: fmt.Errorf("wrapped: %w", &pq.Error{Code: "23505"}), + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, isUniqueViolation(tt.err)) + }) + } +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 376f1029..bb5839f4 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3144,6 +3144,30 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) args = append(args, dim.Endpoint) } + if dim.UserID > 0 { + query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) + args = append(args, dim.UserID) + } + if dim.APIKeyID > 0 { + query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) + args = append(args, dim.APIKeyID) + } + if dim.AccountID > 0 { + query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) + args = append(args, dim.AccountID) + } + if dim.RequestType != nil { + query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1) + args = append(args, *dim.RequestType) + } + if dim.Stream != nil { + query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1) + args = append(args, *dim.Stream) + } + if dim.BillingType != nil { + query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) + args = append(args, *dim.BillingType) + } query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" if limit > 0 { diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index ebc8929a..cb029bd6 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -80,6 +80,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // inbound_endpoint sqlmock.AnyArg(), // upstream_endpoint log.CacheTTLOverridden, + sqlmock.AnyArg(), // channel_id + sqlmock.AnyArg(), // model_mapping_chain + sqlmock.AnyArg(), // billing_tier + sqlmock.AnyArg(), // billing_mode createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) @@ -153,6 +157,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), sqlmock.AnyArg(), log.CacheTTLOverridden, + sqlmock.AnyArg(), // channel_id + sqlmock.AnyArg(), // model_mapping_chain + sqlmock.AnyArg(), // billing_tier + sqlmock.AnyArg(), // billing_mode createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) @@ -463,6 +471,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) @@ -506,6 +518,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) @@ -549,6 +565,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode now, }}) require.NoError(t, err) diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 40a137c1..9530a837 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -51,15 +51,15 @@ type Channel struct { type ChannelModelPricing struct { ID int64 ChannelID int64 - Platform string // 所属平台(anthropic/openai/gemini/...) - Models []string // 绑定的模型列表 - BillingMode BillingMode // 计费模式 - InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 - OutputPrice *float64 // 每 token 输出价格(USD) - CacheWritePrice *float64 // 缓存写入价格 - CacheReadPrice *float64 // 缓存读取价格 - ImageOutputPrice *float64 // 图片输出价格(向后兼容) - PerRequestPrice *float64 // 默认按次计费价格(USD) + Platform string // 所属平台(anthropic/openai/gemini/...) + Models []string // 绑定的模型列表 + BillingMode BillingMode // 计费模式 + InputPrice *float64 // 每 token 输入价格(USD)— 向后兼容 flat 定价 + OutputPrice *float64 // 每 token 输出价格(USD) + CacheWritePrice *float64 // 缓存写入价格 + CacheReadPrice *float64 // 缓存读取价格 + ImageOutputPrice *float64 // 图片输出价格(向后兼容) + PerRequestPrice *float64 // 默认按次计费价格(USD) Intervals []PricingInterval // 区间定价列表 CreatedAt time.Time UpdatedAt time.Time @@ -175,3 +175,11 @@ func (c *Channel) Clone() *Channel { } return &cp } + +// ChannelUsageFields 渠道相关的使用记录字段(嵌入到各平台的 RecordUsageInput 中) +type ChannelUsageFields struct { + ChannelID int64 // 渠道 ID(0 = 无渠道) + OriginalModel string // 用户原始请求模型(渠道映射前) + BillingModelSource string // 计费模型来源:"requested" / "upstream" + ModelMappingChain string // 映射链描述,如 "a→b→c" +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index a1ed7100..b36169d0 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "sort" "strings" "sync/atomic" "time" @@ -17,8 +16,8 @@ import ( ) var ( - ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found") - ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists") + ErrChannelNotFound = infraerrors.NotFound("CHANNEL_NOT_FOUND", "channel not found") + ErrChannelExists = infraerrors.Conflict("CHANNEL_EXISTS", "channel name already exists") ErrGroupAlreadyInChannel = infraerrors.Conflict( "GROUP_ALREADY_IN_CHANNEL", "one or more groups already belong to another channel", @@ -81,12 +80,12 @@ type wildcardMappingEntry struct { // channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) type channelCache struct { // 热路径查找 - pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 + pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序) - mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 - wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序) - channelByGroupID map[int64]*Channel // groupID → 渠道 - groupPlatform map[int64]string // groupID → platform + mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 + wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序) + channelByGroupID map[int64]*Channel // groupID → 渠道 + groupPlatform map[int64]string // groupID → platform // 冷路径(CRUD 操作) byID map[int64]*Channel @@ -118,9 +117,19 @@ func (r ChannelMappingResult) BuildModelMappingChain(reqModel, upstreamModel str return reqModel + "→" + r.MappedModel } +// ToUsageFields 将渠道映射结果转为使用记录字段 +func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) ChannelUsageFields { + return ChannelUsageFields{ + ChannelID: r.ChannelID, + OriginalModel: reqModel, + BillingModelSource: r.BillingModelSource, + ModelMappingChain: r.BuildModelMappingChain(reqModel, upstreamModel), + } +} + const ( - channelCacheTTL = 60 * time.Second - channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelCacheTTL = 60 * time.Second + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelCacheDBTimeout = 10 * time.Second ) @@ -177,14 +186,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 slog.Warn("failed to build channel cache", "error", err) errorCache := &channelCache{ - pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), - mappingByGroupModel: make(map[channelModelKey]string), - wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), - channelByGroupID: make(map[int64]*Channel), - groupPlatform: make(map[int64]string), - byID: make(map[int64]*Channel), - loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL + mappingByGroupModel: make(map[channelModelKey]string), + wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: make(map[int64]string), + byID: make(map[int64]*Channel), + loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL } s.cache.Store(errorCache) return nil, fmt.Errorf("list all channels: %w", err) @@ -205,14 +214,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) } cache := &channelCache{ - pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), - mappingByGroupModel: make(map[channelModelKey]string), - wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), - channelByGroupID: make(map[int64]*Channel), - groupPlatform: groupPlatforms, - byID: make(map[int64]*Channel, len(channels)), - loadedAt: time.Now(), + mappingByGroupModel: make(map[channelModelKey]string), + wildcardMappingByGP: make(map[channelGroupPlatformKey][]*wildcardMappingEntry), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: groupPlatforms, + byID: make(map[int64]*Channel, len(channels)), + loadedAt: time.Now(), } for i := range channels { @@ -266,19 +275,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) } } - // 通配符条目按前缀长度降序排列(最长前缀优先匹配) - for gpKey, entries := range cache.wildcardByGroupPlatform { - sort.Slice(entries, func(i, j int) bool { - return len(entries[i].prefix) > len(entries[j].prefix) - }) - cache.wildcardByGroupPlatform[gpKey] = entries - } - for gpKey, entries := range cache.wildcardMappingByGP { - sort.Slice(entries, func(i, j int) bool { - return len(entries[i].prefix) > len(entries[j].prefix) - }) - cache.wildcardMappingByGP[gpKey] = entries - } + // 通配符条目保持配置顺序(最先匹配到优先) s.cache.Store(cache) return cache, nil @@ -290,7 +287,7 @@ func (s *ChannelService) invalidateCache() { s.cacheSF.Forget("channel_cache") } -// matchWildcard 在通配符定价中查找匹配项(最长前缀优先) +// matchWildcard 在通配符定价中查找匹配项(最先匹配到优先) func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing { gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} wildcards := c.wildcardByGroupPlatform[gpKey] @@ -302,7 +299,7 @@ func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) return nil } -// matchWildcardMapping 在通配符映射中查找匹配项(最长前缀优先) +// matchWildcardMapping 在通配符映射中查找匹配项(最先匹配到优先) func (c *channelCache) matchWildcardMapping(groupID int64, platform, modelLower string) string { gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} wildcards := c.wildcardMappingByGP[gpKey] @@ -479,15 +476,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) Status: StatusActive, BillingModelSource: input.BillingModelSource, RestrictModels: input.RestrictModels, - GroupIDs: input.GroupIDs, - ModelPricing: input.ModelPricing, - ModelMapping: input.ModelMapping, + GroupIDs: input.GroupIDs, + ModelPricing: input.ModelPricing, + ModelMapping: input.ModelMapping, } if channel.BillingModelSource == "" { channel.BillingModelSource = BillingModelSourceRequested } - if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } @@ -558,7 +558,10 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan channel.BillingModelSource = input.BillingModelSource } - if err := validateNoDuplicateModels(channel.ModelPricing); err != nil { + if err := validateNoConflictingModels(channel.ModelPricing); err != nil { + return nil, err + } + if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { return nil, err } @@ -610,16 +613,79 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP return s.repo.List(ctx, params, status, search) } -// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复) -func validateNoDuplicateModels(pricingList []ChannelModelPricing) error { - seen := make(map[string]bool) +// modelEntry 表示一个模型模式条目(用于冲突检测) +type modelEntry struct { + pattern string // 原始模式(如 "claude-*" 或 "claude-opus-4") + prefix string // lowercase 前缀(通配符去掉 *,精确名保持原样) + wildcard bool +} + +// conflictsBetween 检查两个模型模式是否冲突 +func conflictsBetween(a, b modelEntry) bool { + switch { + case !a.wildcard && !b.wildcard: + return a.prefix == b.prefix + case a.wildcard && !b.wildcard: + return strings.HasPrefix(b.prefix, a.prefix) + case !a.wildcard && b.wildcard: + return strings.HasPrefix(a.prefix, b.prefix) + default: + return strings.HasPrefix(a.prefix, b.prefix) || + strings.HasPrefix(b.prefix, a.prefix) + } +} + +// toModelEntry 将模型名转换为 modelEntry +func toModelEntry(pattern string) modelEntry { + lower := strings.ToLower(pattern) + isWild := strings.HasSuffix(lower, "*") + prefix := lower + if isWild { + prefix = strings.TrimSuffix(lower, "*") + } + return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild} +} + +// validateNoConflictingModels 检查定价列表中是否有冲突模型模式(同一平台下)。 +// 冲突包括:精确重复、通配符之间的前缀包含、通配符与精确名的前缀匹配。 +func validateNoConflictingModels(pricingList []ChannelModelPricing) error { + byPlatform := make(map[string][]modelEntry) for _, p := range pricingList { for _, model := range p.Models { - key := p.Platform + ":" + strings.ToLower(model) - if seen[key] { - return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform)) + byPlatform[p.Platform] = append(byPlatform[p.Platform], toModelEntry(model)) + } + } + for platform, entries := range byPlatform { + if err := detectConflicts(entries, platform, "MODEL_PATTERN_CONFLICT", "model patterns"); err != nil { + return err + } + } + return nil +} + +// validateNoConflictingMappings 检查模型映射中是否有冲突的源模式 +func validateNoConflictingMappings(mapping map[string]map[string]string) error { + for platform, platformMapping := range mapping { + entries := make([]modelEntry, 0, len(platformMapping)) + for src := range platformMapping { + entries = append(entries, toModelEntry(src)) + } + if err := detectConflicts(entries, platform, "MAPPING_PATTERN_CONFLICT", "mapping source patterns"); err != nil { + return err + } + } + return nil +} + +// detectConflicts 在一组 modelEntry 中检测冲突,返回带有 errCode 和 label 的错误 +func detectConflicts(entries []modelEntry, platform, errCode, label string) error { + for i := 0; i < len(entries); i++ { + for j := i + 1; j < len(entries); j++ { + if conflictsBetween(entries[i], entries[j]) { + return infraerrors.BadRequest(errCode, + fmt.Sprintf("%s '%s' and '%s' conflict in platform '%s': overlapping match range", + label, entries[i].pattern, entries[j].pattern, platform)) } - seen[key] = true } } return nil diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go new file mode 100644 index 00000000..a3d41ecd --- /dev/null +++ b/backend/internal/service/channel_service_test.go @@ -0,0 +1,1890 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock: ChannelRepository +// --------------------------------------------------------------------------- + +type mockChannelRepository struct { + listAllFn func(ctx context.Context) ([]Channel, error) + getGroupPlatformsFn func(ctx context.Context, groupIDs []int64) (map[int64]string, error) + createFn func(ctx context.Context, channel *Channel) error + getByIDFn func(ctx context.Context, id int64) (*Channel, error) + updateFn func(ctx context.Context, channel *Channel) error + deleteFn func(ctx context.Context, id int64) error + listFn func(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) + existsByNameFn func(ctx context.Context, name string) (bool, error) + existsByNameExcludingFn func(ctx context.Context, name string, excludeID int64) (bool, error) + getGroupIDsFn func(ctx context.Context, channelID int64) ([]int64, error) + setGroupIDsFn func(ctx context.Context, channelID int64, groupIDs []int64) error + getChannelIDByGroupIDFn func(ctx context.Context, groupID int64) (int64, error) + getGroupsInOtherChannelsFn func(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) + listModelPricingFn func(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) + createModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error + updateModelPricingFn func(ctx context.Context, pricing *ChannelModelPricing) error + deleteModelPricingFn func(ctx context.Context, id int64) error + replaceModelPricingFn func(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error +} + +func (m *mockChannelRepository) Create(ctx context.Context, channel *Channel) error { + if m.createFn != nil { + return m.createFn(ctx, channel) + } + return nil +} + +func (m *mockChannelRepository) GetByID(ctx context.Context, id int64) (*Channel, error) { + if m.getByIDFn != nil { + return m.getByIDFn(ctx, id) + } + return nil, ErrChannelNotFound +} + +func (m *mockChannelRepository) Update(ctx context.Context, channel *Channel) error { + if m.updateFn != nil { + return m.updateFn(ctx, channel) + } + return nil +} + +func (m *mockChannelRepository) Delete(ctx context.Context, id int64) error { + if m.deleteFn != nil { + return m.deleteFn(ctx, id) + } + return nil +} + +func (m *mockChannelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]Channel, *pagination.PaginationResult, error) { + if m.listFn != nil { + return m.listFn(ctx, params, status, search) + } + return nil, nil, nil +} + +func (m *mockChannelRepository) ListAll(ctx context.Context) ([]Channel, error) { + if m.listAllFn != nil { + return m.listAllFn(ctx) + } + return nil, nil +} + +func (m *mockChannelRepository) ExistsByName(ctx context.Context, name string) (bool, error) { + if m.existsByNameFn != nil { + return m.existsByNameFn(ctx, name) + } + return false, nil +} + +func (m *mockChannelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) { + if m.existsByNameExcludingFn != nil { + return m.existsByNameExcludingFn(ctx, name, excludeID) + } + return false, nil +} + +func (m *mockChannelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) { + if m.getGroupIDsFn != nil { + return m.getGroupIDsFn(ctx, channelID) + } + return nil, nil +} + +func (m *mockChannelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error { + if m.setGroupIDsFn != nil { + return m.setGroupIDsFn(ctx, channelID, groupIDs) + } + return nil +} + +func (m *mockChannelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + if m.getChannelIDByGroupIDFn != nil { + return m.getChannelIDByGroupIDFn(ctx, groupID) + } + return 0, nil +} + +func (m *mockChannelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) { + if m.getGroupsInOtherChannelsFn != nil { + return m.getGroupsInOtherChannelsFn(ctx, channelID, groupIDs) + } + return nil, nil +} + +func (m *mockChannelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { + if m.getGroupPlatformsFn != nil { + return m.getGroupPlatformsFn(ctx, groupIDs) + } + return nil, nil +} + +func (m *mockChannelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]ChannelModelPricing, error) { + if m.listModelPricingFn != nil { + return m.listModelPricingFn(ctx, channelID) + } + return nil, nil +} + +func (m *mockChannelRepository) CreateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error { + if m.createModelPricingFn != nil { + return m.createModelPricingFn(ctx, pricing) + } + return nil +} + +func (m *mockChannelRepository) UpdateModelPricing(ctx context.Context, pricing *ChannelModelPricing) error { + if m.updateModelPricingFn != nil { + return m.updateModelPricingFn(ctx, pricing) + } + return nil +} + +func (m *mockChannelRepository) DeleteModelPricing(ctx context.Context, id int64) error { + if m.deleteModelPricingFn != nil { + return m.deleteModelPricingFn(ctx, id) + } + return nil +} + +func (m *mockChannelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []ChannelModelPricing) error { + if m.replaceModelPricingFn != nil { + return m.replaceModelPricingFn(ctx, channelID, pricingList) + } + return nil +} + +// --------------------------------------------------------------------------- +// Mock: APIKeyAuthCacheInvalidator +// --------------------------------------------------------------------------- + +type mockChannelAuthCacheInvalidator struct { + invalidatedGroupIDs []int64 + invalidatedKeys []string + invalidatedUserIDs []int64 +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByKey(_ context.Context, key string) { + m.invalidatedKeys = append(m.invalidatedKeys, key) +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { + m.invalidatedUserIDs = append(m.invalidatedUserIDs, userID) +} + +func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context.Context, groupID int64) { + m.invalidatedGroupIDs = append(m.invalidatedGroupIDs, groupID) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func newTestChannelService(repo *mockChannelRepository) *ChannelService { + return NewChannelService(repo, nil) +} + +func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService { + return NewChannelService(repo, auth) +} + + +// makeStandardRepo returns a repo that serves one active channel with anthropic pricing +// for group 1, with the given model pricing and model mapping. +func makeStandardRepo(ch Channel, groupPlatforms map[int64]string) *mockChannelRepository { + return &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return groupPlatforms, nil + }, + } +} + +// =========================================================================== +// 1. BuildModelMappingChain +// =========================================================================== + +func TestBuildModelMappingChain(t *testing.T) { + tests := []struct { + name string + result ChannelMappingResult + requestModel string + upstreamModel string + want string + }{ + { + name: "no mapping, no upstream diff", + result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4", + want: "", + }, + { + name: "no mapping, upstream differs", + result: ChannelMappingResult{Mapped: false, MappedModel: "claude-sonnet-4"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4-20250514", + want: "claude-sonnet-4\u2192claude-sonnet-4-20250514", + }, + { + name: "mapped, upstream differs", + result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"}, + requestModel: "my-model", + upstreamModel: "actual-upstream", + want: "my-model\u2192claude-sonnet-4-20250514\u2192actual-upstream", + }, + { + name: "mapped, upstream same as mapped", + result: ChannelMappingResult{Mapped: true, MappedModel: "claude-sonnet-4-20250514"}, + requestModel: "claude-sonnet-4", + upstreamModel: "claude-sonnet-4-20250514", + want: "claude-sonnet-4\u2192claude-sonnet-4-20250514", + }, + { + name: "mapped, upstream empty", + result: ChannelMappingResult{Mapped: true, MappedModel: "target-model"}, + requestModel: "my-model", + upstreamModel: "", + want: "my-model\u2192target-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.result.BuildModelMappingChain(tt.requestModel, tt.upstreamModel) + require.Equal(t, tt.want, got) + }) + } +} + +// =========================================================================== +// 2. ReplaceModelInBody +// =========================================================================== + +func TestReplaceModelInBody(t *testing.T) { + tests := []struct { + name string + body []byte + newModel string + check func(t *testing.T, result []byte) + }{ + { + name: "empty body", + body: []byte{}, + newModel: "new-model", + check: func(t *testing.T, result []byte) { + require.Equal(t, []byte{}, result) + }, + }, + { + name: "model already equal", + body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), + newModel: "claude-sonnet-4", + check: func(t *testing.T, result []byte) { + require.Equal(t, []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), result) + }, + }, + { + name: "model different", + body: []byte(`{"model":"claude-sonnet-4","temperature":0.7}`), + newModel: "claude-opus-4", + check: func(t *testing.T, result []byte) { + require.Contains(t, string(result), `"model":"claude-opus-4"`) + require.Contains(t, string(result), `"temperature"`) + }, + }, + { + name: "no model field", + body: []byte(`{"temperature":0.7}`), + newModel: "claude-opus-4", + check: func(t *testing.T, result []byte) { + require.Contains(t, string(result), `"model":"claude-opus-4"`) + require.Contains(t, string(result), `"temperature"`) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ReplaceModelInBody(tt.body, tt.newModel) + tt.check(t, result) + }) + } +} + +// =========================================================================== +// 3. validateNoConflictingModels + validateNoConflictingMappings +// =========================================================================== + +func TestValidateNoConflictingModels(t *testing.T) { + tests := []struct { + name string + pricingList []ChannelModelPricing + wantErr bool + errContains string + }{ + { + name: "no duplicates", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4", "claude-opus-4"}}, + {Platform: "openai", Models: []string{"gpt-5.1"}}, + }, + wantErr: false, + }, + { + name: "same platform duplicate", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + wantErr: true, + errContains: "claude-sonnet-4", + }, + { + name: "same model different platform", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"model-a"}}, + {Platform: "openai", Models: []string{"model-a"}}, + }, + wantErr: false, + }, + { + name: "case insensitive", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"Claude"}}, + {Platform: "anthropic", Models: []string{"claude"}}, + }, + wantErr: true, + }, + { + name: "empty list (nil)", + pricingList: nil, + wantErr: false, + }, + { + name: "wildcard_vs_wildcard_conflict", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "wildcard_vs_exact_conflict", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + {Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "no_conflict_different_platform", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + {Platform: "openai", Models: []string{"claude-*"}}, + }, + wantErr: false, + }, + { + name: "no_conflict_same_platform_different_prefix", + pricingList: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-*"}}, + {Platform: "anthropic", Models: []string{"gpt-*"}}, + }, + wantErr: false, + }, + { + name: "catch_all_wildcard_conflicts_with_everything", + pricingList: []ChannelModelPricing{ + {Platform: "openai", Models: []string{"*"}}, + {Platform: "openai", Models: []string{"gpt-5"}}, + }, + wantErr: true, + errContains: "conflict", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNoConflictingModels(tt.pricingList) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } + + // Additional sub-case: explicit empty slice + t.Run("empty list (empty slice)", func(t *testing.T) { + err := validateNoConflictingModels([]ChannelModelPricing{}) + require.NoError(t, err) + }) +} + +func TestValidateNoConflictingMappings(t *testing.T) { + tests := []struct { + name string + mapping map[string]map[string]string + wantErr bool + errContains string + }{ + { + name: "nil mapping", + mapping: nil, + wantErr: false, + }, + { + name: "empty mapping", + mapping: map[string]map[string]string{}, + wantErr: false, + }, + { + name: "no conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-opus-*": "opus", "gpt-*": "gpt"}, + }, + wantErr: false, + }, + { + name: "wildcard vs wildcard conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-*": "a", "claude-opus-*": "b"}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "wildcard vs exact conflict", + mapping: map[string]map[string]string{ + "openai": {"gpt-*": "a", "gpt-4o": "b"}, + }, + wantErr: true, + errContains: "conflict", + }, + { + name: "exact duplicate conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-opus-4": "a"}, + "openai": {"claude-opus-4": "b"}, + }, + wantErr: false, // different platforms + }, + { + name: "different platforms no conflict", + mapping: map[string]map[string]string{ + "anthropic": {"claude-*": "a"}, + "openai": {"claude-*": "b"}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNoConflictingMappings(tt.mapping) + if tt.wantErr { + require.Error(t, err) + if tt.errContains != "" { + require.Contains(t, err.Error(), tt.errContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestConflictsBetween(t *testing.T) { + tests := []struct { + name string + a, b modelEntry + want bool + }{ + { + name: "exact same", + a: modelEntry{prefix: "claude-opus-4", wildcard: false}, + b: modelEntry{prefix: "claude-opus-4", wildcard: false}, + want: true, + }, + { + name: "exact different", + a: modelEntry{prefix: "claude-opus-4", wildcard: false}, + b: modelEntry{prefix: "gpt-4o", wildcard: false}, + want: false, + }, + { + name: "wildcard matches exact", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "claude-opus-4", wildcard: false}, + want: true, + }, + { + name: "exact does not match unrelated wildcard", + a: modelEntry{prefix: "gpt-4o", wildcard: false}, + b: modelEntry{prefix: "claude-", wildcard: true}, + want: false, + }, + { + name: "wildcard prefix overlap", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "claude-opus-", wildcard: true}, + want: true, + }, + { + name: "wildcards no overlap", + a: modelEntry{prefix: "claude-", wildcard: true}, + b: modelEntry{prefix: "gpt-", wildcard: true}, + want: false, + }, + { + name: "catch-all wildcard vs any", + a: modelEntry{prefix: "", wildcard: true}, + b: modelEntry{prefix: "anything", wildcard: false}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, conflictsBetween(tt.a, tt.b)) + }) + } +} + +// =========================================================================== +// 4. Cache Building + Hot Path Methods +// =========================================================================== + +// --- 4.1 GetChannelForGroup --- + +func TestGetChannelForGroup_Success(t *testing.T) { + ch := Channel{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, int64(1), result.ID) + require.Equal(t, "test-channel", result.Name) + + // returned value should be a clone + result.Name = "mutated" + result2, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.Equal(t, "test-channel", result2.Name) +} + +func TestGetChannelForGroup_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.Nil(t, result) +} + +func TestGetChannelForGroup_NoChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 999) + require.NoError(t, err) + require.Nil(t, result) +} + +func TestGetChannelForGroup_CacheError(t *testing.T) { + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, errors.New("db connection failed") + }, + } + svc := newTestChannelService(repo) + + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "db connection failed") +} + +// --- 4.2 GetChannelModelPricing --- + +func TestGetChannelModelPricing_ExactMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + require.InDelta(t, 15e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_CaseInsensitive(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "Claude-Opus-4") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) +} + +func TestGetChannelModelPricing_WildcardMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4") + require.NotNil(t, result) + require.Equal(t, int64(200), result.ID) +} + +func TestGetChannelModelPricing_WildcardFirstMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 200, Platform: "anthropic", Models: []string{"claude-*"}, InputPrice: testPtrFloat64(10e-6)}, + {ID: 300, Platform: "anthropic", Models: []string{"claude-sonnet-*"}, InputPrice: testPtrFloat64(5e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-sonnet-4-20250514") + require.NotNil(t, result) + // "claude-*" is defined first, so it matches first regardless of prefix length + require.Equal(t, int64(200), result.ID) + require.InDelta(t, 10e-6, *result.InputPrice, 1e-12) +} + +func TestGetChannelModelPricing_NoMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1") + require.Nil(t, result) +} + +func TestGetChannelModelPricing_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.Nil(t, result) +} + +func TestGetChannelModelPricing_PlatformFiltering(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "openai", Models: []string{"gpt-5.1"}, InputPrice: testPtrFloat64(5e-6)}, + {ID: 200, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic", 20: "openai"}) + svc := newTestChannelService(repo) + + // Group 10 (anthropic) should NOT see openai pricing + result := svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1") + require.Nil(t, result) + + // Group 10 (anthropic) should see anthropic pricing + result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, int64(200), result.ID) + + // Group 20 (openai) should see openai pricing + result = svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + + // Group 20 (openai) should NOT see anthropic pricing + result = svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4") + require.Nil(t, result) +} + +func TestGetChannelModelPricing_ReturnsCopy(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + + // Mutate the returned pricing's slice fields — original cache should not be affected + // (Clone copies slices independently, pointer fields are shared per design) + result.Models = append(result.Models, "hacked") + result.ID = 999 + + // Original cache should not be affected (slice independence + struct copy) + result2 := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result2) + require.Equal(t, 1, len(result2.Models)) + require.Equal(t, int64(100), result2.ID) +} + +// --- 4.3 ResolveChannelMapping --- + +func TestResolveChannelMapping_NoChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // Group 999 is not in any channel + result := svc.ResolveChannelMapping(context.Background(), 999, "claude-opus-4") + require.Equal(t, "claude-opus-4", result.MappedModel) + require.False(t, result.Mapped) + require.Equal(t, int64(0), result.ChannelID) +} + +func TestResolveChannelMapping_ExactMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "claude-sonnet-4-20250514", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4") + require.True(t, result.Mapped) + require.Equal(t, "claude-sonnet-4-20250514", result.MappedModel) + require.Equal(t, int64(1), result.ChannelID) +} + +func TestResolveChannelMapping_WildcardMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "*": "gpt-5.4", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "any-model-name") + require.True(t, result.Mapped) + require.Equal(t, "gpt-5.4", result.MappedModel) +} + +func TestResolveChannelMapping_WildcardFirstMatch(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-*": "target2", + "claude-sonnet-*": "target1", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4") + require.True(t, result.Mapped) + // map iteration order is non-deterministic, so the first-match depends on + // insertion order which Go maps don't guarantee; verify that one of the + // wildcard targets matched + require.Contains(t, []string{"target1", "target2"}, result.MappedModel) +} + +func TestResolveChannelMapping_NoMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "mapped", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.False(t, result.Mapped) + require.Equal(t, "claude-opus-4", result.MappedModel) + require.Equal(t, int64(1), result.ChannelID) +} + +func TestResolveChannelMapping_DefaultBillingModelSource(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + BillingModelSource: "", // empty + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.Equal(t, BillingModelSourceRequested, result.BillingModelSource) +} + +func TestResolveChannelMapping_UpstreamBillingModelSource(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + BillingModelSource: BillingModelSourceUpstream, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.Equal(t, BillingModelSourceUpstream, result.BillingModelSource) +} + +func TestResolveChannelMapping_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "mapped", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-sonnet-4") + require.False(t, result.Mapped) + require.Equal(t, "claude-sonnet-4", result.MappedModel) + require.Equal(t, int64(0), result.ChannelID) // no channel +} + +// --- 4.4 IsModelRestricted --- + +func TestIsModelRestricted_NoChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // Group 999 is not in any channel + restricted := svc.IsModelRestricted(context.Background(), 999, "claude-opus-4") + require.False(t, restricted) +} + +func TestIsModelRestricted_RestrictDisabled(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: false, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // Even though model is not in pricing, RestrictModels=false + restricted := svc.IsModelRestricted(context.Background(), 10, "nonexistent-model") + require.False(t, restricted) +} + +func TestIsModelRestricted_InactiveChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusDisabled, + GroupIDs: []int64{10}, + RestrictModels: true, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "any-model") + require.False(t, restricted) +} + +func TestIsModelRestricted_ModelInPricing(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4", "claude-sonnet-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "claude-opus-4") + require.False(t, restricted) +} + +func TestIsModelRestricted_ModelInWildcard(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-*"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "claude-sonnet-4") + require.False(t, restricted) +} + +func TestIsModelRestricted_ModelNotFound(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "gpt-5.1") + require.True(t, restricted) +} + +func TestIsModelRestricted_CaseInsensitive(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + restricted := svc.IsModelRestricted(context.Background(), 10, "Claude-Opus-4") + require.False(t, restricted) +} + +// --- 4.5 ResolveChannelMappingAndRestrict --- + +func TestResolveChannelMappingAndRestrict_NilGroupID(t *testing.T) { + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), nil, "claude-opus-4") + require.False(t, restricted) + require.False(t, mapping.Mapped) + require.Equal(t, "claude-opus-4", mapping.MappedModel) +} + +func TestResolveChannelMappingAndRestrict_ModelInPricing_WithMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4": "claude-sonnet-4-20250514", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + gid := int64(10) + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "claude-sonnet-4") + require.False(t, restricted) // model IS in pricing + require.True(t, mapping.Mapped) + require.Equal(t, "claude-sonnet-4-20250514", mapping.MappedModel) +} + +func TestResolveChannelMappingAndRestrict_ModelNotInPricing_WithMapping(t *testing.T) { + // CRITICAL: this test verifies that restriction checks the ORIGINAL model + // against pricing BEFORE applying mapping. The model "unknown-model" is NOT + // in pricing, so even though the wildcard mapping "*" matches it, it should + // still be restricted. + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "*": "catch-all-target", + }, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + gid := int64(10) + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model") + require.True(t, restricted) // model NOT in pricing, even though mapping exists + require.True(t, mapping.Mapped) + require.Equal(t, "catch-all-target", mapping.MappedModel) +} + +func TestResolveChannelMappingAndRestrict_ModelNotInPricing_NoMapping(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + RestrictModels: true, + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-sonnet-4"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + gid := int64(10) + mapping, restricted := svc.ResolveChannelMappingAndRestrict(context.Background(), &gid, "unknown-model") + require.True(t, restricted) // model NOT in pricing + require.False(t, mapping.Mapped) + require.Equal(t, "unknown-model", mapping.MappedModel) +} + +// --- 4.6 Cache Building Specifics --- + +func TestBuildCache_DBError(t *testing.T) { + callCount := 0 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + callCount++ + return nil, errors.New("database down") + }, + } + svc := newTestChannelService(repo) + + // First call should fail + _, err := svc.GetChannelForGroup(context.Background(), 10) + require.Error(t, err) + require.Contains(t, err.Error(), "database down") + require.Equal(t, 1, callCount) + + // Second call within error-TTL should use error cache, but still return error + // Because buildCache stores error-TTL cache and returns error, the cached value + // is still within TTL and loadCache returns it (which is an empty cache). + // Actually, re-reading the code: buildCache returns nil, err, and the error cache + // only serves as a "don't retry immediately" mechanism. The singleflight.Do + // returns the error. On next call within error-TTL, the cache has an empty but + // valid entry, so loadCache returns it (with empty maps). GetChannelForGroup + // will find nothing and return nil, nil. + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.Nil(t, result) + // Should NOT have hit DB again (error-TTL cache is active) + require.Equal(t, 1, callCount) +} + +func TestBuildCache_GroupPlatformError(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return nil, errors.New("group platforms failed") + }, + } + svc := newTestChannelService(repo) + + // Should degrade gracefully: channel is found, but without platform info + // pricing won't match because platform will be "" and pricing platform is "anthropic" + result, err := svc.GetChannelForGroup(context.Background(), 10) + require.NoError(t, err) + require.NotNil(t, result) // channel still found + require.Equal(t, int64(1), result.ID) +} + +func TestBuildCache_MultipleGroupsSameChannel(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20, 30}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(15e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{ + 10: "anthropic", + 20: "anthropic", + 30: "anthropic", + }) + svc := newTestChannelService(repo) + + for _, gid := range []int64{10, 20, 30} { + result := svc.GetChannelModelPricing(context.Background(), gid, "claude-opus-4") + require.NotNil(t, result, "group %d should have pricing", gid) + require.Equal(t, int64(100), result.ID) + } +} + +func TestBuildCache_PlatformFiltering(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + {ID: 200, Platform: "openai", Models: []string{"gpt-5.1"}}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{ + 10: "anthropic", + 20: "openai", + }) + svc := newTestChannelService(repo) + + // anthropic group sees only anthropic models + require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4")) + require.Nil(t, svc.GetChannelModelPricing(context.Background(), 10, "gpt-5.1")) + + // openai group sees only openai models + require.NotNil(t, svc.GetChannelModelPricing(context.Background(), 20, "gpt-5.1")) + require.Nil(t, svc.GetChannelModelPricing(context.Background(), 20, "claude-opus-4")) +} + +func TestBuildCache_WildcardPreservesConfigOrder(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + // Configuration order: shortest prefix first + {ID: 100, Platform: "anthropic", Models: []string{"c-*"}, InputPrice: testPtrFloat64(1e-6)}, + {ID: 200, Platform: "anthropic", Models: []string{"c-son-*"}, InputPrice: testPtrFloat64(2e-6)}, + {ID: 300, Platform: "anthropic", Models: []string{"c-son-4-*"}, InputPrice: testPtrFloat64(3e-6)}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: "anthropic"}) + svc := newTestChannelService(repo) + + // "c-son-4-xxx" matches all three wildcards, but "c-*" (ID=100) is first in config + result := svc.GetChannelModelPricing(context.Background(), 10, "c-son-4-xxx") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + + // "c-son-yyy" matches "c-*" and "c-son-*", but "c-*" (ID=100) is first + result = svc.GetChannelModelPricing(context.Background(), 10, "c-son-yyy") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) + + // "c-other" only matches "c-*" (ID=100) + result = svc.GetChannelModelPricing(context.Background(), 10, "c-other") + require.NotNil(t, result) + require.Equal(t, int64(100), result.ID) +} + +// --- 4.7 invalidateCache --- + +func TestInvalidateCache(t *testing.T) { + callCount := 0 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + callCount++ + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{10: "anthropic"}, nil + }, + } + svc := newTestChannelService(repo) + + // First load + result := svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, 1, callCount) + + // Second call should use cache + result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, 1, callCount) // no new DB call + + // Invalidate + svc.invalidateCache() + + // Next call should rebuild from DB + result = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.NotNil(t, result) + require.Equal(t, 2, callCount) // rebuilt +} + +// =========================================================================== +// 5. CRUD Methods +// =========================================================================== + +// --- 5.1 Create --- + +func TestCreate_Success(t *testing.T) { + createdID := int64(42) + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + return nil, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + ch.ID = createdID + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "new-channel", Status: StatusActive}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + GroupIDs: []int64{10}, + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, createdID, result.ID) +} + +func TestCreate_NameExists(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return true, nil + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "existing-channel", + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrChannelExists) +} + +func TestCreate_GroupConflict(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + return []int64{10}, nil // group 10 already in another channel + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + GroupIDs: []int64{10, 20}, + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrGroupAlreadyInChannel) +} + +func TestCreate_DuplicateModel(t *testing.T) { + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + ModelPricing: []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, // duplicate + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "claude-opus-4") +} + +func TestCreate_DefaultBillingModelSource(t *testing.T) { + var capturedChannel *Channel + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + capturedChannel = ch + ch.ID = 1 + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return capturedChannel, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "new-channel", + BillingModelSource: "", // empty, should default to "requested" + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, BillingModelSourceRequested, result.BillingModelSource) +} + +func TestCreate_InvalidatesCache(t *testing.T) { + loadCount := 0 + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10}, + ModelPricing: []ChannelModelPricing{ + {ID: 100, Platform: "anthropic", Models: []string{"claude-opus-4"}}, + }, + } + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + loadCount++ + return []Channel{ch}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{10: "anthropic"}, nil + }, + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + createFn: func(_ context.Context, c *Channel) error { + c.ID = 2 + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "new", Status: StatusActive}, nil + }, + } + svc := newTestChannelService(repo) + + // Load cache + _ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.Equal(t, 1, loadCount) + + // Create triggers cache invalidation + _, err := svc.Create(context.Background(), &CreateChannelInput{Name: "new"}) + require.NoError(t, err) + + // Next cache access should rebuild + _ = svc.GetChannelModelPricing(context.Background(), 10, "claude-opus-4") + require.Equal(t, 2, loadCount) +} + +// --- 5.2 Update --- + +func TestUpdate_Success(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, _ *Channel) error { + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Name: "updated-name", + Description: testPtrString("new desc"), + }) + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestUpdate_NotFound(t *testing.T) { + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return nil, ErrChannelNotFound + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Update(context.Background(), 999, &UpdateChannelInput{ + Name: "whatever", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "channel") +} + +func TestUpdate_NameConflict(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + existsByNameExcludingFn: func(_ context.Context, _ string, _ int64) (bool, error) { + return true, nil // name conflicts with another channel + }, + } + svc := newTestChannelService(repo) + + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Name: "conflicting-name", + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrChannelExists) +} + +func TestUpdate_GroupConflict(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + return []int64{20}, nil // group 20 in another channel + }, + } + svc := newTestChannelService(repo) + + newGroupIDs := []int64{10, 20} + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + GroupIDs: &newGroupIDs, + }) + require.Error(t, err) + require.ErrorIs(t, err, ErrGroupAlreadyInChannel) +} + +func TestUpdate_DuplicateModel(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + } + svc := newTestChannelService(repo) + + dupPricing := []ChannelModelPricing{ + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + {Platform: "anthropic", Models: []string{"claude-opus-4"}}, + } + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + ModelPricing: &dupPricing, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "claude-opus-4") +} + +func TestUpdate_InvalidatesChannelCache(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + loadCount := 0 + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, _ *Channel) error { + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return []int64{10, 20}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + loadCount++ + return []Channel{*existing}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + // Load cache first + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 1, loadCount) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Description: testPtrString("updated"), + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Channel cache should be invalidated (next access rebuilds) + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 2, loadCount) +} + +func TestUpdate_InvalidatesAuthCache(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "original", + Status: StatusActive, + } + auth := &mockChannelAuthCacheInvalidator{} + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, _ *Channel) error { + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return []int64{10, 20}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelServiceWithAuth(repo, auth) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Description: testPtrString("updated"), + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Auth cache should be invalidated for both groups + require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs) +} + +// --- 5.3 Delete --- + +func TestChannelDelete_Success(t *testing.T) { + deleted := false + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + deleteFn: func(_ context.Context, _ int64) error { + deleted = true + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + require.True(t, deleted) +} + +func TestChannelDelete_InvalidatesCaches(t *testing.T) { + auth := &mockChannelAuthCacheInvalidator{} + loadCount := 0 + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return []int64{10, 20}, nil + }, + deleteFn: func(_ context.Context, _ int64) error { + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + loadCount++ + return []Channel{{ID: 1, Status: StatusActive, GroupIDs: []int64{10, 20}}}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return nil, nil + }, + } + svc := newTestChannelServiceWithAuth(repo, auth) + + // Load cache first + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 1, loadCount) + + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + + // Auth cache invalidated for both groups + require.ElementsMatch(t, []int64{10, 20}, auth.invalidatedGroupIDs) + + // Channel cache invalidated + _, _ = svc.GetChannelForGroup(context.Background(), 10) + require.Equal(t, 2, loadCount) +} + +func TestChannelDelete_NotFound(t *testing.T) { + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + deleteFn: func(_ context.Context, _ int64) error { + return errors.New("record not found") + }, + } + svc := newTestChannelService(repo) + + err := svc.Delete(context.Background(), 999) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") +} + +// =========================================================================== +// 6. Edge Case Tests +// =========================================================================== + +// --- 6.1 Create with empty GroupIDs --- + +func TestCreate_NoGroups(t *testing.T) { + createdID := int64(55) + getGroupsInOtherChannelsCalled := false + repo := &mockChannelRepository{ + existsByNameFn: func(_ context.Context, _ string) (bool, error) { + return false, nil + }, + getGroupsInOtherChannelsFn: func(_ context.Context, _ int64, _ []int64) ([]int64, error) { + getGroupsInOtherChannelsCalled = true + return nil, nil + }, + createFn: func(_ context.Context, ch *Channel) error { + ch.ID = createdID + return nil + }, + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return &Channel{ID: id, Name: "no-groups-channel", Status: StatusActive}, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "no-groups-channel", + GroupIDs: []int64{}, // empty slice + }) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, createdID, result.ID) + // GetGroupsInOtherChannels should NOT have been called (skipped by len(input.GroupIDs) > 0) + require.False(t, getGroupsInOtherChannelsCalled) +} + +// --- 6.2 Update only Status --- + +func TestUpdate_StatusOnly(t *testing.T) { + existing := &Channel{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + } + var capturedChannel *Channel + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, id int64) (*Channel, error) { + return existing.Clone(), nil + }, + updateFn: func(_ context.Context, ch *Channel) error { + capturedChannel = ch + return nil + }, + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + result, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + Status: StatusDisabled, + }) + require.NoError(t, err) + require.NotNil(t, result) + // Verify that the channel passed to repo.Update has the new status + require.NotNil(t, capturedChannel) + require.Equal(t, StatusDisabled, capturedChannel.Status) + // Name should remain unchanged + require.Equal(t, "test-channel", capturedChannel.Name) +} + +// --- 6.3 Delete when GetGroupIDs fails --- + +func TestChannelDelete_GetGroupIDsError(t *testing.T) { + deleted := false + repo := &mockChannelRepository{ + getGroupIDsFn: func(_ context.Context, _ int64) ([]int64, error) { + return nil, errors.New("group IDs lookup failed") + }, + deleteFn: func(_ context.Context, _ int64) error { + deleted = true + return nil + }, + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, nil + }, + } + svc := newTestChannelService(repo) + + // Delete should still succeed even though GetGroupIDs returned error (degradation path L588-591) + err := svc.Delete(context.Background(), 1) + require.NoError(t, err) + require.True(t, deleted) +} + +// --- 6.4 ReplaceModelInBody with invalid JSON --- + +func TestReplaceModelInBody_InvalidJSON(t *testing.T) { + // Case 1: broken JSON object — gjson won't find "model", sjson does best-effort set + // (no panic, no error from sjson, but result is mutated garbage) + brokenBody := []byte("{broken") + result := ReplaceModelInBody(brokenBody, "new-model") + require.NotNil(t, result) + // sjson does not error on this input, so result differs from original — just verify no panic + + // Case 2: JSON array — sjson.SetBytes returns error on non-object, + // triggering the L447 error fallback path that returns original body. + arrayBody := []byte("[]") + result2 := ReplaceModelInBody(arrayBody, "new-model") + require.Equal(t, arrayBody, result2) +} diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index 46cf193a..d01c252b 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -8,13 +8,10 @@ import ( "github.com/stretchr/testify/require" ) -func channelTestPtrFloat64(v float64) *float64 { return &v } -func channelTestPtrInt(v int) *int { return &v } - func TestGetModelPricing(t *testing.T) { ch := &Channel{ ModelPricing: []ChannelModelPricing{ - {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: channelTestPtrFloat64(3e-6)}, + {ID: 1, Models: []string{"claude-sonnet-4"}, BillingMode: BillingModeToken, InputPrice: testPtrFloat64(3e-6)}, {ID: 3, Models: []string{"gpt-5.1"}, BillingMode: BillingModePerRequest}, }, } @@ -48,7 +45,7 @@ func TestGetModelPricing(t *testing.T) { func TestGetModelPricing_ReturnsCopy(t *testing.T) { ch := &Channel{ ModelPricing: []ChannelModelPricing{ - {ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: channelTestPtrFloat64(3e-6)}, + {ID: 1, Models: []string{"claude-sonnet-4"}, InputPrice: testPtrFloat64(3e-6)}, }, } @@ -73,23 +70,23 @@ func TestGetModelPricing_EmptyPricing(t *testing.T) { func TestGetIntervalForContext(t *testing.T) { p := &ChannelModelPricing{ Intervals: []PricingInterval{ - {MinTokens: 0, MaxTokens: channelTestPtrInt(128000), InputPrice: channelTestPtrFloat64(1e-6)}, - {MinTokens: 128000, MaxTokens: nil, InputPrice: channelTestPtrFloat64(2e-6)}, + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6)}, }, } tests := []struct { - name string - tokens int - wantPrice *float64 - wantNil bool + name string + tokens int + wantPrice *float64 + wantNil bool }{ - {"first interval", 50000, channelTestPtrFloat64(1e-6), false}, + {"first interval", 50000, testPtrFloat64(1e-6), false}, // (min, max] — 128000 在第一个区间的 max,包含,所以匹配第一个 - {"boundary: max of first (inclusive)", 128000, channelTestPtrFloat64(1e-6), false}, + {"boundary: max of first (inclusive)", 128000, testPtrFloat64(1e-6), false}, // 128001 > 128000,匹配第二个区间 - {"boundary: just above first max", 128001, channelTestPtrFloat64(2e-6), false}, - {"unbounded interval", 500000, channelTestPtrFloat64(2e-6), false}, + {"boundary: just above first max", 128001, testPtrFloat64(2e-6), false}, + {"unbounded interval", 500000, testPtrFloat64(2e-6), false}, // (0, max] — 0 不匹配任何区间(左开) {"zero tokens: no match", 0, nil, true}, } @@ -110,11 +107,11 @@ func TestGetIntervalForContext(t *testing.T) { func TestGetIntervalForContext_NoMatch(t *testing.T) { p := &ChannelModelPricing{ Intervals: []PricingInterval{ - {MinTokens: 10000, MaxTokens: channelTestPtrInt(50000)}, + {MinTokens: 10000, MaxTokens: testPtrInt(50000)}, }, } - require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min - require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open) + require.Nil(t, p.GetIntervalForContext(5000)) // 5000 <= 10000, not > min + require.Nil(t, p.GetIntervalForContext(10000)) // 10000 not > 10000 (left-open) require.NotNil(t, p.GetIntervalForContext(50000)) // 50000 <= 50000 (right-closed) require.Nil(t, p.GetIntervalForContext(50001)) // 50001 > 50000 } @@ -127,9 +124,9 @@ func TestGetIntervalForContext_Empty(t *testing.T) { func TestGetTierByLabel(t *testing.T) { p := &ChannelModelPricing{ Intervals: []PricingInterval{ - {TierLabel: "1K", PerRequestPrice: channelTestPtrFloat64(0.04)}, - {TierLabel: "2K", PerRequestPrice: channelTestPtrFloat64(0.08)}, - {TierLabel: "HD", PerRequestPrice: channelTestPtrFloat64(0.12)}, + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "HD", PerRequestPrice: testPtrFloat64(0.12)}, }, } @@ -171,7 +168,7 @@ func TestChannelClone(t *testing.T) { { ID: 100, Models: []string{"model-a"}, - InputPrice: channelTestPtrFloat64(5e-6), + InputPrice: testPtrFloat64(5e-6), }, }, } @@ -211,3 +208,102 @@ func TestChannelModelPricingClone(t *testing.T) { cloned.Intervals[0].TierLabel = "hacked" require.Equal(t, "tier1", original.Intervals[0].TierLabel) } + +// --- BillingMode.IsValid --- + +func TestBillingModeIsValid(t *testing.T) { + tests := []struct { + name string + mode BillingMode + want bool + }{ + {"token", BillingModeToken, true}, + {"per_request", BillingModePerRequest, true}, + {"image", BillingModeImage, true}, + {"empty", BillingMode(""), true}, + {"unknown", BillingMode("unknown"), false}, + {"random", BillingMode("xyz"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, tt.mode.IsValid()) + }) + } +} + +// --- Channel.IsActive --- + +func TestChannelIsActive(t *testing.T) { + tests := []struct { + name string + status string + want bool + }{ + {"active", StatusActive, true}, + {"disabled", "disabled", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := &Channel{Status: tt.status} + require.Equal(t, tt.want, ch.IsActive()) + }) + } +} + +// --- ChannelModelPricing.Clone edge cases --- + +func TestChannelModelPricingClone_EdgeCases(t *testing.T) { + t.Run("nil models", func(t *testing.T) { + original := ChannelModelPricing{Models: nil} + cloned := original.Clone() + require.Nil(t, cloned.Models) + }) + + t.Run("nil intervals", func(t *testing.T) { + original := ChannelModelPricing{Intervals: nil} + cloned := original.Clone() + require.Nil(t, cloned.Intervals) + }) + + t.Run("empty models", func(t *testing.T) { + original := ChannelModelPricing{Models: []string{}} + cloned := original.Clone() + require.NotNil(t, cloned.Models) + require.Empty(t, cloned.Models) + }) +} + +// --- Channel.Clone edge cases --- + +func TestChannelClone_EdgeCases(t *testing.T) { + t.Run("nil model mapping", func(t *testing.T) { + original := &Channel{ID: 1, ModelMapping: nil} + cloned := original.Clone() + require.Nil(t, cloned.ModelMapping) + }) + + t.Run("nil model pricing", func(t *testing.T) { + original := &Channel{ID: 1, ModelPricing: nil} + cloned := original.Clone() + require.Nil(t, cloned.ModelPricing) + }) + + t.Run("deep copy model mapping", func(t *testing.T) { + original := &Channel{ + ID: 1, + ModelMapping: map[string]map[string]string{ + "openai": {"gpt-4": "gpt-4-turbo"}, + }, + } + cloned := original.Clone() + + // Modify the cloned nested map + cloned.ModelMapping["openai"]["gpt-4"] = "hacked" + + // Original must remain unchanged + require.Equal(t, "gpt-4-turbo", original.ModelMapping["openai"]["gpt-4"]) + }) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5a866d63..a420fa9b 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7407,11 +7407,7 @@ type RecordUsageInput struct { ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 - // 渠道映射信息(由 handler 在 Forward 前解析) - ChannelID int64 // 渠道 ID(0 = 无渠道) - OriginalModel string // 用户原始请求模型(渠道映射前) - BillingModelSource string // 计费模型来源:"requested" / "upstream" - ModelMappingChain string // 映射链描述,如 "a→b→c" + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } // APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage @@ -7940,11 +7936,7 @@ type RecordUsageLongContextInput struct { ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) - // 渠道映射信息(由 handler 在 Forward 前解析) - ChannelID int64 // 渠道 ID(0 = 无渠道) - OriginalModel string // 用户原始请求模型(渠道映射前) - BillingModelSource string // 计费模型来源:"requested" / "upstream" - ModelMappingChain string // 映射链描述,如 "a→b→c" + ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go index 5b4a0b13..74a7c3e2 100644 --- a/backend/internal/service/model_pricing_resolver_test.go +++ b/backend/internal/service/model_pricing_resolver_test.go @@ -4,14 +4,12 @@ package service import ( "context" + "errors" "testing" "github.com/stretchr/testify/require" ) -func resolverPtrFloat64(v float64) *float64 { return &v } -func resolverPtrInt(v int) *int { return &v } - func newTestBillingServiceForResolver() *BillingService { bs := &BillingService{ fallbackPrices: make(map[string]*ModelPricing), @@ -83,8 +81,8 @@ func TestGetIntervalPricing_MatchesInterval(t *testing.T) { BasePricing: &ModelPricing{InputPricePerToken: 5e-6}, SupportsCacheBreakdown: true, Intervals: []PricingInterval{ - {MinTokens: 0, MaxTokens: resolverPtrInt(128000), InputPrice: resolverPtrFloat64(1e-6), OutputPrice: resolverPtrFloat64(2e-6)}, - {MinTokens: 128000, MaxTokens: nil, InputPrice: resolverPtrFloat64(3e-6), OutputPrice: resolverPtrFloat64(6e-6)}, + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(2e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(6e-6)}, }, } @@ -108,7 +106,7 @@ func TestGetIntervalPricing_NoMatch_FallsBackToBase(t *testing.T) { Mode: BillingModeToken, BasePricing: basePricing, Intervals: []PricingInterval{ - {MinTokens: 10000, MaxTokens: resolverPtrInt(50000), InputPrice: resolverPtrFloat64(1e-6)}, + {MinTokens: 10000, MaxTokens: testPtrInt(50000), InputPrice: testPtrFloat64(1e-6)}, }, } @@ -123,8 +121,8 @@ func TestGetRequestTierPrice(t *testing.T) { resolved := &ResolvedPricing{ Mode: BillingModePerRequest, RequestTiers: []PricingInterval{ - {TierLabel: "1K", PerRequestPrice: resolverPtrFloat64(0.04)}, - {TierLabel: "2K", PerRequestPrice: resolverPtrFloat64(0.08)}, + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, }, } @@ -140,8 +138,8 @@ func TestGetRequestTierPriceByContext(t *testing.T) { resolved := &ResolvedPricing{ Mode: BillingModePerRequest, RequestTiers: []PricingInterval{ - {MinTokens: 0, MaxTokens: resolverPtrInt(128000), PerRequestPrice: resolverPtrFloat64(0.05)}, - {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: resolverPtrFloat64(0.10)}, + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, }, } @@ -162,3 +160,428 @@ func TestGetRequestTierPrice_NilPerRequestPrice(t *testing.T) { require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "1K"), 1e-12) } + +// =========================================================================== +// Channel override tests — exercises applyChannelOverrides via Resolve +// =========================================================================== + +// helper: creates a resolver wired to a ChannelService that returns the given +// channel (active, groupID=100, platform=anthropic) with the specified pricing. +func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelPricingResolver { + t.Helper() + const groupID = 100 + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return []Channel{{ + ID: 1, + Name: "test-channel", + Status: StatusActive, + GroupIDs: []int64{groupID}, + ModelPricing: pricing, + }}, nil + }, + getGroupPlatformsFn: func(_ context.Context, _ []int64) (map[int64]string, error) { + return map[int64]string{groupID: "anthropic"}, nil + }, + } + cs := NewChannelService(repo, nil) + bs := newTestBillingServiceForResolver() + return NewModelPricingResolver(cs, bs) +} + +// groupIDPtr returns a pointer to groupID 100 (the test constant). +func groupIDPtr() *int64 { v := int64(100); return &v } + +// --------------------------------------------------------------------------- +// 1. Token mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_TokenFlat(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(10e-6), + OutputPrice: testPtrFloat64(50e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeToken, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, resolved.BasePricing.InputPricePerTokenPriority, 1e-12) + require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) + require.InDelta(t, 50e-6, resolved.BasePricing.OutputPricePerTokenPriority, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenPartialOverride(t *testing.T) { + // Channel only sets InputPrice; OutputPrice should remain from the base (LiteLLM/fallback). + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(20e-6), + // OutputPrice intentionally nil + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + require.NotNil(t, resolved.BasePricing) + // InputPrice overridden by channel + require.InDelta(t, 20e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + // OutputPrice kept from base (fallback: 15e-6) + require.InDelta(t, 15e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenWithIntervals(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(8e-6)}, + {MinTokens: 128000, MaxTokens: nil, InputPrice: testPtrFloat64(4e-6), OutputPrice: testPtrFloat64(16e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + require.Len(t, resolved.Intervals, 2) + + // GetIntervalPricing should use channel intervals + iv := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, iv) + require.InDelta(t, 2e-6, iv.InputPricePerToken, 1e-12) + require.InDelta(t, 8e-6, iv.OutputPricePerToken, 1e-12) + + iv2 := r.GetIntervalPricing(resolved, 200000) + require.NotNil(t, iv2) + require.InDelta(t, 4e-6, iv2.InputPricePerToken, 1e-12) + require.InDelta(t, 16e-6, iv2.OutputPricePerToken, 1e-12) +} + +func TestResolve_WithChannelOverride_TokenNilBasePricing(t *testing.T) { + // Base pricing is nil (unknown model), channel has flat prices → creates new BasePricing. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"unknown-model-xyz"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(7e-6), + OutputPrice: testPtrFloat64(21e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "unknown-model-xyz", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, "channel", resolved.Source) + // BasePricing was nil from resolveBasePricing but applyTokenOverrides creates a new one + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 7e-6, resolved.BasePricing.InputPricePerToken, 1e-12) + require.InDelta(t, 21e-6, resolved.BasePricing.OutputPricePerToken, 1e-12) +} + +// --------------------------------------------------------------------------- +// 2. Per-request mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_PerRequest(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.03)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModePerRequest, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.InDelta(t, 0.05, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 2) + + // Verify tier lookups + require.InDelta(t, 0.03, r.GetRequestTierPriceByContext(resolved, 50000), 1e-12) + require.InDelta(t, 0.10, r.GetRequestTierPriceByContext(resolved, 200000), 1e-12) +} + +func TestResolve_WithChannelOverride_PerRequestNilPrice(t *testing.T) { + // PerRequestPrice nil → DefaultPerRequestPrice stays 0. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModePerRequest, + // PerRequestPrice intentionally nil + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.02)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModePerRequest, resolved.Mode) + require.InDelta(t, 0.0, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 1) +} + +// --------------------------------------------------------------------------- +// 3. Image mode overrides +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_Image(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.08), + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.NotNil(t, resolved) + require.Equal(t, BillingModeImage, resolved.Mode) + require.Equal(t, "channel", resolved.Source) + require.InDelta(t, 0.08, resolved.DefaultPerRequestPrice, 1e-12) + require.Len(t, resolved.RequestTiers, 3) +} + +func TestResolve_WithChannelOverride_ImageTierLabels(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeImage, + Intervals: []PricingInterval{ + {TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.08)}, + {TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.16)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.InDelta(t, 0.04, r.GetRequestTierPrice(resolved, "1K"), 1e-12) + require.InDelta(t, 0.08, r.GetRequestTierPrice(resolved, "2K"), 1e-12) + require.InDelta(t, 0.16, r.GetRequestTierPrice(resolved, "4K"), 1e-12) + require.InDelta(t, 0.0, r.GetRequestTierPrice(resolved, "8K"), 1e-12) // not found +} + +// --------------------------------------------------------------------------- +// 4. Source tracking & default mode +// --------------------------------------------------------------------------- + +func TestResolve_WithChannelOverride_SourceIsChannel(t *testing.T) { + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(1e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.Equal(t, "channel", resolved.Source) +} + +func TestResolve_WithChannelOverride_DefaultMode(t *testing.T) { + // Channel pricing with empty BillingMode → defaults to BillingModeToken. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: "", // intentionally empty + InputPrice: testPtrFloat64(5e-6), + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + require.Equal(t, "channel", resolved.Source) + require.Equal(t, BillingModeToken, resolved.Mode) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 5e-6, resolved.BasePricing.InputPricePerToken, 1e-12) +} + +// --------------------------------------------------------------------------- +// 5. GetIntervalPricing integration after channel override +// --------------------------------------------------------------------------- + +func TestGetIntervalPricing_WithChannelIntervals(t *testing.T) { + // Channel provides intervals that override the base pricing path. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(1e-6), OutputPrice: testPtrFloat64(5e-6)}, + {MinTokens: 100000, MaxTokens: nil, InputPrice: testPtrFloat64(2e-6), OutputPrice: testPtrFloat64(10e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + // Token count 50000 matches first interval + pricing := r.GetIntervalPricing(resolved, 50000) + require.NotNil(t, pricing) + require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12) + require.InDelta(t, 5e-6, pricing.OutputPricePerToken, 1e-12) + + // Token count 150000 matches second interval + pricing2 := r.GetIntervalPricing(resolved, 150000) + require.NotNil(t, pricing2) + require.InDelta(t, 2e-6, pricing2.InputPricePerToken, 1e-12) + require.InDelta(t, 10e-6, pricing2.OutputPricePerToken, 1e-12) +} + +func TestGetIntervalPricing_ChannelIntervalsNoMatch(t *testing.T) { + // Channel intervals don't match token count → falls back to BasePricing. + r := newResolverWithChannel(t, []ChannelModelPricing{{ + Platform: "anthropic", + Models: []string{"claude-sonnet-4"}, + BillingMode: BillingModeToken, + Intervals: []PricingInterval{ + // Only covers tokens > 50000 + {MinTokens: 50000, MaxTokens: testPtrInt(200000), InputPrice: testPtrFloat64(9e-6)}, + }, + }}) + + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: groupIDPtr(), + }) + + // Token count 1000 doesn't match any interval (1000 <= 50000 minTokens) + pricing := r.GetIntervalPricing(resolved, 1000) + // Should fall back to BasePricing (from the billing service fallback) + require.NotNil(t, pricing) + require.Equal(t, resolved.BasePricing, pricing) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) // original base price +} + +// =========================================================================== +// 6. Error path tests +// =========================================================================== + +func TestResolve_WithChannelOverride_CacheError(t *testing.T) { + // When ListAll returns an error, the ChannelService cache build fails. + // Resolve should gracefully fall back to base pricing without panicking. + repo := &mockChannelRepository{ + listAllFn: func(_ context.Context) ([]Channel, error) { + return nil, errors.New("database unavailable") + }, + } + cs := NewChannelService(repo, nil) + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(cs, bs) + + gid := int64(100) + resolved := r.Resolve(context.Background(), PricingInput{ + Model: "claude-sonnet-4", + GroupID: &gid, + }) + + require.NotNil(t, resolved) + // Should NOT panic, should NOT have source "channel" + require.NotEqual(t, "channel", resolved.Source) + // Base pricing should still be present (from BillingService fallback) + require.NotNil(t, resolved.BasePricing) + require.InDelta(t, 3e-6, resolved.BasePricing.InputPricePerToken, 1e-12) +} + +// =========================================================================== +// 7. GetRequestTierPriceByContext boundary tests +// =========================================================================== + +func TestGetRequestTierPriceByContext_EmptyTiers(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: nil, // empty + } + + price := r.GetRequestTierPriceByContext(resolved, 50000) + require.InDelta(t, 0.0, price, 1e-12) + + // Also test with explicit empty slice + resolved2 := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{}, + } + + price2 := r.GetRequestTierPriceByContext(resolved2, 50000) + require.InDelta(t, 0.0, price2, 1e-12) +} + +func TestGetRequestTierPriceByContext_ExactBoundary(t *testing.T) { + bs := newTestBillingServiceForResolver() + r := NewModelPricingResolver(&ChannelService{}, bs) + + resolved := &ResolvedPricing{ + Mode: BillingModePerRequest, + RequestTiers: []PricingInterval{ + {MinTokens: 0, MaxTokens: testPtrInt(128000), PerRequestPrice: testPtrFloat64(0.05)}, + {MinTokens: 128000, MaxTokens: nil, PerRequestPrice: testPtrFloat64(0.10)}, + }, + } + + // totalContextTokens = 128000 exactly: + // FindMatchingInterval checks: totalTokens > MinTokens && totalTokens <= MaxTokens + // For first interval: 128000 > 0 (true) && 128000 <= 128000 (true) → matches first interval + price := r.GetRequestTierPriceByContext(resolved, 128000) + require.InDelta(t, 0.05, price, 1e-12) + + // totalContextTokens = 128001 should match second interval + // For first interval: 128001 > 0 (true) && 128001 <= 128000 (false) → no match + // For second interval: 128001 > 128000 (true) && MaxTokens == nil → matches + price2 := r.GetRequestTierPriceByContext(resolved, 128001) + require.InDelta(t, 0.10, price2, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f68562f8..f9e92c78 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4146,10 +4146,7 @@ type OpenAIRecordUsageInput struct { IPAddress string // 请求的客户端 IP 地址 RequestPayloadHash string APIKeyService APIKeyQuotaUpdater - ChannelID int64 - OriginalModel string - BillingModelSource string - ModelMappingChain string + ChannelUsageFields } // RecordUsage records usage and deducts balance diff --git a/backend/internal/service/testhelpers_test.go b/backend/internal/service/testhelpers_test.go new file mode 100644 index 00000000..73750e27 --- /dev/null +++ b/backend/internal/service/testhelpers_test.go @@ -0,0 +1,15 @@ +//go:build unit + +package service + +// testPtrFloat64 returns a pointer to the given float64 value. +func testPtrFloat64(v float64) *float64 { return &v } + +// testPtrInt returns a pointer to the given int value. +func testPtrInt(v int) *int { return &v } + +// testPtrString returns a pointer to the given string value. +func testPtrString(v string) *string { return &v } + +// testPtrBool returns a pointer to the given bool value. +func testPtrBool(v bool) *bool { return &v } diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 15d1540f..49e487b7 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -167,6 +167,13 @@ export interface UserBreakdownParams { endpoint?: string endpoint_type?: 'inbound' | 'upstream' | 'path' limit?: number + // Additional filter conditions + user_id?: number + api_key_id?: number + account_id?: number + request_type?: number + stream?: boolean + billing_type?: number | null } export interface UserBreakdownResponse { diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts index 8df71f84..cea57da0 100644 --- a/frontend/src/components/admin/channel/types.ts +++ b/frontend/src/components/admin/channel/types.ts @@ -73,6 +73,45 @@ export function formIntervalsToAPI(intervals: IntervalFormEntry[]): PricingInter })) } +// ── 模型模式冲突检测 ────────────────────────────────────── + +interface ModelPattern { + pattern: string + prefix: string // lowercase, 通配符去掉尾部 * + wildcard: boolean +} + +function toModelPattern(model: string): ModelPattern { + const lower = model.toLowerCase() + const wildcard = lower.endsWith('*') + return { + pattern: model, + prefix: wildcard ? lower.slice(0, -1) : lower, + wildcard, + } +} + +function patternsConflict(a: ModelPattern, b: ModelPattern): boolean { + if (!a.wildcard && !b.wildcard) return a.prefix === b.prefix + if (a.wildcard && !b.wildcard) return b.prefix.startsWith(a.prefix) + if (!a.wildcard && b.wildcard) return a.prefix.startsWith(b.prefix) + // 双通配符:任一前缀是另一前缀的前缀即冲突 + return a.prefix.startsWith(b.prefix) || b.prefix.startsWith(a.prefix) +} + +/** 检测模型模式列表中的冲突,返回冲突的两个模式名;无冲突返回 null */ +export function findModelConflict(models: string[]): [string, string] | null { + const patterns = models.map(toModelPattern) + for (let i = 0; i < patterns.length; i++) { + for (let j = i + 1; j < patterns.length; j++) { + if (patternsConflict(patterns[i], patterns[j])) { + return [patterns[i].pattern, patterns[j].pattern] + } + } + } + return null +} + /** 平台对应的模型 tag 样式(背景+文字) */ export function getPlatformTagClass(platform: string): string { switch (platform) { diff --git a/frontend/src/components/charts/EndpointDistributionChart.vue b/frontend/src/components/charts/EndpointDistributionChart.vue index 5e3fc23b..32e05a93 100644 --- a/frontend/src/components/charts/EndpointDistributionChart.vue +++ b/frontend/src/components/charts/EndpointDistributionChart.vue @@ -161,6 +161,7 @@ const props = withDefaults( showSourceToggle?: boolean startDate?: string endDate?: string + filters?: Record }>(), { upstreamEndpointStats: () => [], @@ -193,6 +194,7 @@ const toggleBreakdown = async (endpoint: string) => { breakdownItems.value = [] try { const res = await getUserBreakdown({ + ...props.filters, start_date: props.startDate, end_date: props.endDate, endpoint, diff --git a/frontend/src/components/charts/GroupDistributionChart.vue b/frontend/src/components/charts/GroupDistributionChart.vue index f2be366f..560529b1 100644 --- a/frontend/src/components/charts/GroupDistributionChart.vue +++ b/frontend/src/components/charts/GroupDistributionChart.vue @@ -125,6 +125,7 @@ const props = withDefaults(defineProps<{ showMetricToggle?: boolean startDate?: string endDate?: string + filters?: Record }>(), { loading: false, metric: 'tokens', @@ -150,6 +151,7 @@ const toggleBreakdown = async (type: string, id: number | string) => { breakdownItems.value = [] try { const res = await getUserBreakdown({ + ...props.filters, start_date: props.startDate, end_date: props.endDate, group_id: Number(id), diff --git a/frontend/src/components/charts/ModelDistributionChart.vue b/frontend/src/components/charts/ModelDistributionChart.vue index a88da0c4..820eada3 100644 --- a/frontend/src/components/charts/ModelDistributionChart.vue +++ b/frontend/src/components/charts/ModelDistributionChart.vue @@ -270,6 +270,7 @@ const props = withDefaults(defineProps<{ rankingError?: boolean startDate?: string endDate?: string + filters?: Record }>(), { upstreamModelStats: () => [], mappingModelStats: () => [], @@ -302,6 +303,7 @@ const toggleBreakdown = async (type: string, id: string) => { breakdownItems.value = [] try { const res = await getUserBreakdown({ + ...props.filters, start_date: props.startDate, end_date: props.endDate, model: id, diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 73df077f..9c557e2c 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1744,6 +1744,8 @@ export default { deleteError: 'Failed to delete channel', nameRequired: 'Please enter a channel name', duplicateModels: 'Model "{0}" appears in multiple pricing entries', + modelConflict: "Model patterns '{model1}' and '{model2}' conflict: overlapping match range", + mappingConflict: "Mapping source patterns '{model1}' and '{model2}' conflict: overlapping match range", deleteConfirm: 'Are you sure you want to delete channel "{name}"? This cannot be undone.', columns: { name: 'Name', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index d5dd769e..942766a4 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1824,6 +1824,8 @@ export default { deleteError: '删除渠道失败', nameRequired: '请输入渠道名称', duplicateModels: '模型「{0}」在多个定价条目中重复', + modelConflict: "模型模式 '{model1}' 和 '{model2}' 冲突:匹配范围重叠", + mappingConflict: "模型映射源 '{model1}' 和 '{model2}' 冲突:匹配范围重叠", deleteConfirm: '确定要删除渠道「{name}」吗?此操作不可撤销。', columns: { name: '名称', diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index b38199e8..0111d25e 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -418,7 +418,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels' import type { PricingFormEntry } from '@/components/admin/channel/types' -import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI } from '@/components/admin/channel/types' +import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' import type { Column } from '@/components/common/types' import AppLayout from '@/components/layout/AppLayout.vue' @@ -875,19 +875,35 @@ async function handleSubmit() { } } - // Check duplicate models per platform (same model in different platforms is allowed) + // Check model pattern conflicts per platform (duplicate / wildcard overlap) for (const section of form.platforms.filter(s => s.enabled)) { - const seen = new Set() + // Collect all pricing models for this platform + const allModels: string[] = [] for (const entry of section.model_pricing) { - for (const m of entry.models) { - const key = m.toLowerCase() - if (seen.has(key)) { - const platformLabel = t('admin.groups.platforms.' + section.platform, section.platform) - appStore.showError(t('admin.channels.duplicateModels', `${platformLabel} 平台下模型 "${m}" 在多个定价条目中重复`)) - activeTab.value = section.platform - return - } - seen.add(key) + allModels.push(...entry.models) + } + const pricingConflict = findModelConflict(allModels) + if (pricingConflict) { + appStore.showError( + t('admin.channels.modelConflict', + { model1: pricingConflict[0], model2: pricingConflict[1] }, + `模型模式 '${pricingConflict[0]}' 和 '${pricingConflict[1]}' 冲突:匹配范围重叠`) + ) + activeTab.value = section.platform + return + } + // Check model mapping source pattern conflicts + const mappingKeys = Object.keys(section.model_mapping) + if (mappingKeys.length > 0) { + const mappingConflict = findModelConflict(mappingKeys) + if (mappingConflict) { + appStore.showError( + t('admin.channels.mappingConflict', + { model1: mappingConflict[0], model2: mappingConflict[1] }, + `模型映射源 '${mappingConflict[0]}' 和 '${mappingConflict[1]}' 冲突:匹配范围重叠`) + ) + activeTab.value = section.platform + return } } } diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue index 6ec36c86..567769d9 100644 --- a/frontend/src/views/admin/UsageView.vue +++ b/frontend/src/views/admin/UsageView.vue @@ -34,6 +34,7 @@ :show-metric-toggle="true" :start-date="startDate" :end-date="endDate" + :filters="breakdownFilters" />
@@ -57,6 +59,7 @@ :title="t('usage.endpointDistribution')" :start-date="startDate" :end-date="endDate" + :filters="breakdownFilters" />
@@ -169,6 +172,17 @@ const cleanupDialogVisible = ref(false) const showBalanceHistoryModal = ref(false) const balanceHistoryUser = ref(null) +const breakdownFilters = computed(() => { + const f: Record = {} + if (filters.value.user_id) f.user_id = filters.value.user_id + if (filters.value.api_key_id) f.api_key_id = filters.value.api_key_id + if (filters.value.account_id) f.account_id = filters.value.account_id + if (filters.value.group_id) f.group_id = filters.value.group_id + if (filters.value.request_type != null) f.request_type = filters.value.request_type + if (filters.value.billing_type != null) f.billing_type = filters.value.billing_type + return f +}) + const handleUserClick = async (userId: number) => { try { const user = await adminAPI.users.getById(userId)