diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 4e95035a..7568fa50 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -174,7 +174,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) digestSessionStore := service.NewDigestSessionStore() channelRepository := repository.NewChannelRepository(db) - channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) + channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator) + availableChannelHandler := admin.NewAvailableChannelHandler(channelService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) @@ -234,7 +235,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler) + availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, availableChannelHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) @@ -246,7 +248,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentWebhookHandler := handler.NewPaymentWebhookHandler(paymentService, registry) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, channelMonitorUserHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler, handlerPaymentHandler, paymentWebhookHandler, availableChannelUserHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) diff --git a/backend/internal/handler/admin/available_channel_handler.go b/backend/internal/handler/admin/available_channel_handler.go new file mode 100644 index 00000000..53776105 --- /dev/null +++ b/backend/internal/handler/admin/available_channel_handler.go @@ -0,0 +1,99 @@ +package admin + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// AvailableChannelHandler 处理「可用渠道」聚合视图的管理员接口。 +// +// 该视图以只读方式聚合渠道基础信息、关联分组与推导出的支持模型列表(无通配符)。 +type AvailableChannelHandler struct { + channelService *service.ChannelService +} + +// NewAvailableChannelHandler 创建 AvailableChannelHandler 实例。 +func NewAvailableChannelHandler(channelService *service.ChannelService) *AvailableChannelHandler { + return &AvailableChannelHandler{channelService: channelService} +} + +// availableGroupResponse 响应中的分组概要。 +type availableGroupResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` +} + +// supportedModelResponse 响应中的支持模型条目。 +type supportedModelResponse struct { + Name string `json:"name"` + Platform string `json:"platform"` + Pricing *channelModelPricingResponse `json:"pricing"` +} + +// availableChannelResponse 管理员视图完整字段集。 +type availableChannelResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + Groups []availableGroupResponse `json:"groups"` + SupportedModels []supportedModelResponse `json:"supported_models"` +} + +// AvailableChannelToAdminResponse 将 service 层的 AvailableChannel 转为管理员 DTO。 +// 导出供同 package 的复用;也用于构造测试 fixture。 +func AvailableChannelToAdminResponse(ch service.AvailableChannel) availableChannelResponse { + groups := make([]availableGroupResponse, 0, len(ch.Groups)) + for _, g := range ch.Groups { + groups = append(groups, availableGroupResponse{ID: g.ID, Name: g.Name, Platform: g.Platform}) + } + models := make([]supportedModelResponse, 0, len(ch.SupportedModels)) + for i := range ch.SupportedModels { + m := ch.SupportedModels[i] + var pricing *channelModelPricingResponse + if m.Pricing != nil { + p := pricingToResponse(m.Pricing) + pricing = &p + } + models = append(models, supportedModelResponse{ + Name: m.Name, + Platform: m.Platform, + Pricing: pricing, + }) + } + billingSource := ch.BillingModelSource + if billingSource == "" { + billingSource = service.BillingModelSourceChannelMapped + } + return availableChannelResponse{ + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, + BillingModelSource: billingSource, + RestrictModels: ch.RestrictModels, + Groups: groups, + SupportedModels: models, + } +} + +// List 列出所有可用渠道(管理员视图)。 +// GET /api/v1/admin/channels/available +func (h *AvailableChannelHandler) List(c *gin.Context) { + channels, err := h.channelService.ListAvailable(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]availableChannelResponse, 0, len(channels)) + for _, ch := range channels { + out = append(out, AvailableChannelToAdminResponse(ch)) + } + response.Success(c, gin.H{"items": out}) +} diff --git a/backend/internal/handler/admin/available_channel_handler_test.go b/backend/internal/handler/admin/available_channel_handler_test.go new file mode 100644 index 00000000..687e8dad --- /dev/null +++ b/backend/internal/handler/admin/available_channel_handler_test.go @@ -0,0 +1,57 @@ +//go:build unit + +package admin + +import ( + "encoding/json" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestAvailableChannelToAdminResponse_IncludesFullDTO(t *testing.T) { + // 管理员视图应包含 id / status / billing_model_source / restrict_models 等 + // 管理字段;BillingModelSource 为空时应默认回填 channel_mapped。 + input := service.AvailableChannel{ + ID: 42, + Name: "ch", + Description: "d", + Status: service.StatusActive, + BillingModelSource: "", // 验证默认值填充 + RestrictModels: true, + Groups: []service.AvailableGroupRef{ + {ID: 1, Name: "g1", Platform: "anthropic"}, + }, + SupportedModels: []service.SupportedModel{ + {Name: "claude-sonnet-4-6", Platform: "anthropic"}, + }, + } + + resp := AvailableChannelToAdminResponse(input) + require.Equal(t, int64(42), resp.ID) + require.Equal(t, "ch", resp.Name) + require.Equal(t, service.StatusActive, resp.Status) + require.Equal(t, service.BillingModelSourceChannelMapped, resp.BillingModelSource) + require.True(t, resp.RestrictModels) + require.Len(t, resp.Groups, 1) + require.Len(t, resp.SupportedModels, 1) + + // JSON 层验证管理字段确实会被序列化。 + raw, err := json.Marshal(resp) + require.NoError(t, err) + var decoded map[string]any + require.NoError(t, json.Unmarshal(raw, &decoded)) + for _, key := range []string{"id", "status", "billing_model_source", "restrict_models", "groups", "supported_models"} { + _, exists := decoded[key] + require.Truef(t, exists, "admin DTO must expose %q", key) + } +} + +func TestAvailableChannelToAdminResponse_PreservesExplicitBillingSource(t *testing.T) { + input := service.AvailableChannel{ + BillingModelSource: service.BillingModelSourceUpstream, + } + resp := AvailableChannelToAdminResponse(input) + require.Equal(t, service.BillingModelSourceUpstream, resp.BillingModelSource) +} diff --git a/backend/internal/handler/available_channel_handler.go b/backend/internal/handler/available_channel_handler.go new file mode 100644 index 00000000..25452fc8 --- /dev/null +++ b/backend/internal/handler/available_channel_handler.go @@ -0,0 +1,216 @@ +package handler + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// AvailableChannelHandler 处理用户侧「可用渠道」查询。 +// +// 用户侧接口委托 ChannelService.ListAvailable,并在返回前做三层过滤: +// 1. 行过滤:只保留状态为 Active 且与当前用户可访问分组有交集的渠道; +// 2. 分组过滤:渠道的 Groups 只保留用户可访问的那些; +// 3. 平台过滤:渠道的 SupportedModels 只保留平台在用户可见 Groups 中出现过的模型, +// 防止"渠道同时挂在 antigravity / anthropic 两个平台的分组上,用户只访问 +// antigravity,却看到 anthropic 模型"这类跨平台信息泄漏; +// 4. 字段白名单:仅返回用户需要的字段(省略 BillingModelSource / RestrictModels +// / 内部 ID / Status 等管理字段)。 +type AvailableChannelHandler struct { + channelService *service.ChannelService + apiKeyService *service.APIKeyService +} + +// NewAvailableChannelHandler 创建用户侧可用渠道 handler。 +func NewAvailableChannelHandler( + channelService *service.ChannelService, + apiKeyService *service.APIKeyService, +) *AvailableChannelHandler { + return &AvailableChannelHandler{ + channelService: channelService, + apiKeyService: apiKeyService, + } +} + +// userAvailableGroup 用户可见的分组概要(白名单字段)。 +type userAvailableGroup struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` +} + +// userSupportedModelPricing 用户可见的定价字段白名单。 +type userSupportedModelPricing struct { + BillingMode string `json:"billing_mode"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + ImageOutputPrice *float64 `json:"image_output_price"` + PerRequestPrice *float64 `json:"per_request_price"` + Intervals []userPricingIntervalDTO `json:"intervals"` +} + +// userPricingIntervalDTO 定价区间白名单(去掉内部 ID、SortOrder 等前端不渲染的字段)。 +type userPricingIntervalDTO struct { + MinTokens int `json:"min_tokens"` + MaxTokens *int `json:"max_tokens"` + TierLabel string `json:"tier_label,omitempty"` + InputPrice *float64 `json:"input_price"` + OutputPrice *float64 `json:"output_price"` + CacheWritePrice *float64 `json:"cache_write_price"` + CacheReadPrice *float64 `json:"cache_read_price"` + PerRequestPrice *float64 `json:"per_request_price"` +} + +// userSupportedModel 用户可见的支持模型条目。 +type userSupportedModel struct { + Name string `json:"name"` + Platform string `json:"platform"` + Pricing *userSupportedModelPricing `json:"pricing"` +} + +// userAvailableChannel 用户可见的渠道条目(白名单字段)。 +type userAvailableChannel struct { + Name string `json:"name"` + Description string `json:"description"` + Groups []userAvailableGroup `json:"groups"` + SupportedModels []userSupportedModel `json:"supported_models"` +} + +// List 列出当前用户可见的「可用渠道」。 +// GET /api/v1/channels/available +func (h *AvailableChannelHandler) List(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + userGroups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + allowedGroupIDs := make(map[int64]struct{}, len(userGroups)) + for i := range userGroups { + allowedGroupIDs[userGroups[i].ID] = struct{}{} + } + + channels, err := h.channelService.ListAvailable(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + out := make([]userAvailableChannel, 0, len(channels)) + for _, ch := range channels { + if ch.Status != service.StatusActive { + continue + } + visibleGroups := filterUserVisibleGroups(ch.Groups, allowedGroupIDs) + if len(visibleGroups) == 0 { + continue + } + allowedPlatforms := collectGroupPlatforms(visibleGroups) + out = append(out, userAvailableChannel{ + Name: ch.Name, + Description: ch.Description, + Groups: visibleGroups, + SupportedModels: toUserSupportedModels(ch.SupportedModels, allowedPlatforms), + }) + } + + response.Success(c, out) +} + +// collectGroupPlatforms 聚合 visible groups 覆盖的平台集合,用于过滤 SupportedModels。 +func collectGroupPlatforms(groups []userAvailableGroup) map[string]struct{} { + set := make(map[string]struct{}, len(groups)) + for _, g := range groups { + if g.Platform == "" { + continue + } + set[g.Platform] = struct{}{} + } + return set +} + +// filterUserVisibleGroups 仅保留用户可访问的分组。 +func filterUserVisibleGroups( + groups []service.AvailableGroupRef, + allowed map[int64]struct{}, +) []userAvailableGroup { + visible := make([]userAvailableGroup, 0, len(groups)) + for _, g := range groups { + if _, ok := allowed[g.ID]; !ok { + continue + } + visible = append(visible, userAvailableGroup{ + ID: g.ID, + Name: g.Name, + Platform: g.Platform, + }) + } + return visible +} + +// toUserSupportedModels 将 service 层支持模型转换为用户 DTO(字段白名单)。 +// 仅保留平台在 allowedPlatforms 中的条目,防止跨平台模型信息泄漏。 +// allowedPlatforms 为 nil 时不做平台过滤(保留全部,供测试或明确无过滤场景使用)。 +func toUserSupportedModels( + src []service.SupportedModel, + allowedPlatforms map[string]struct{}, +) []userSupportedModel { + out := make([]userSupportedModel, 0, len(src)) + for i := range src { + m := src[i] + if allowedPlatforms != nil { + if _, ok := allowedPlatforms[m.Platform]; !ok { + continue + } + } + out = append(out, userSupportedModel{ + Name: m.Name, + Platform: m.Platform, + Pricing: toUserPricing(m.Pricing), + }) + } + return out +} + +// toUserPricing 将 service 层定价转换为用户 DTO;入参为 nil 时返回 nil。 +func toUserPricing(p *service.ChannelModelPricing) *userSupportedModelPricing { + if p == nil { + return nil + } + intervals := make([]userPricingIntervalDTO, 0, len(p.Intervals)) + for _, iv := range p.Intervals { + intervals = append(intervals, userPricingIntervalDTO{ + MinTokens: iv.MinTokens, + MaxTokens: iv.MaxTokens, + TierLabel: iv.TierLabel, + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + }) + } + billingMode := string(p.BillingMode) + if billingMode == "" { + billingMode = string(service.BillingModeToken) + } + return &userSupportedModelPricing{ + BillingMode: billingMode, + InputPrice: p.InputPrice, + OutputPrice: p.OutputPrice, + CacheWritePrice: p.CacheWritePrice, + CacheReadPrice: p.CacheReadPrice, + ImageOutputPrice: p.ImageOutputPrice, + PerRequestPrice: p.PerRequestPrice, + Intervals: intervals, + } +} diff --git a/backend/internal/handler/available_channel_handler_test.go b/backend/internal/handler/available_channel_handler_test.go new file mode 100644 index 00000000..cc2ca33a --- /dev/null +++ b/backend/internal/handler/available_channel_handler_test.go @@ -0,0 +1,121 @@ +//go:build unit + +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestUserAvailableChannel_Unauthenticated401(t *testing.T) { + // 没有 AuthSubject 注入时,handler 应返回 401 且不触达 service 依赖。 + gin.SetMode(gin.TestMode) + h := &AvailableChannelHandler{} // nil services — 401 路径不会调用它们 + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/channels/available", nil) + + h.List(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestFilterUserVisibleGroups_IntersectionOnly(t *testing.T) { + // 渠道挂在 {g1, g2, g3},用户只允许 {g1, g3} —— 响应必须仅含 g1/g3。 + groups := []service.AvailableGroupRef{ + {ID: 1, Name: "g1", Platform: "anthropic"}, + {ID: 2, Name: "g2", Platform: "anthropic"}, + {ID: 3, Name: "g3", Platform: "openai"}, + } + allowed := map[int64]struct{}{1: {}, 3: {}} + + visible := filterUserVisibleGroups(groups, allowed) + require.Len(t, visible, 2) + ids := []int64{visible[0].ID, visible[1].ID} + require.ElementsMatch(t, []int64{1, 3}, ids) +} + +func TestCollectGroupPlatforms_DerivesAllowedSet(t *testing.T) { + groups := []userAvailableGroup{ + {ID: 1, Platform: "anthropic"}, + {ID: 2, Platform: "openai"}, + {ID: 3, Platform: "anthropic"}, // 去重 + {ID: 4, Platform: ""}, // 空平台忽略 + } + got := collectGroupPlatforms(groups) + require.Len(t, got, 2) + _, hasAnt := got["anthropic"] + _, hasOA := got["openai"] + require.True(t, hasAnt) + require.True(t, hasOA) +} + +func TestToUserSupportedModels_FiltersByAllowedPlatforms(t *testing.T) { + // 用户可访问分组只覆盖 anthropic;anthropic 平台的模型保留,openai 模型被剔除。 + src := []service.SupportedModel{ + {Name: "claude-sonnet-4-6", Platform: "anthropic", Pricing: nil}, + {Name: "gpt-4o", Platform: "openai", Pricing: nil}, + } + allowed := map[string]struct{}{"anthropic": {}} + out := toUserSupportedModels(src, allowed) + require.Len(t, out, 1) + require.Equal(t, "claude-sonnet-4-6", out[0].Name) +} + +func TestToUserSupportedModels_NilAllowedPlatformsKeepsAll(t *testing.T) { + // 显式传 nil allowedPlatforms 表示不做过滤。 + src := []service.SupportedModel{ + {Name: "a", Platform: "anthropic"}, + {Name: "b", Platform: "openai"}, + } + require.Len(t, toUserSupportedModels(src, nil), 2) +} + +func TestUserAvailableChannel_FieldWhitelist(t *testing.T) { + // 通过序列化 userAvailableChannel 结构体验证响应形状: + // 只有 name / description / groups / supported_models;不含管理端字段。 + row := userAvailableChannel{ + Name: "ch", + Description: "d", + Groups: []userAvailableGroup{{ID: 1, Name: "g1", Platform: "anthropic"}}, + SupportedModels: []userSupportedModel{}, + } + raw, err := json.Marshal(row) + require.NoError(t, err) + var decoded map[string]any + require.NoError(t, json.Unmarshal(raw, &decoded)) + + for _, key := range []string{"id", "status", "billing_model_source", "restrict_models"} { + _, exists := decoded[key] + require.Falsef(t, exists, "user DTO must not expose %q", key) + } + for _, key := range []string{"name", "description", "groups", "supported_models"} { + _, exists := decoded[key] + require.Truef(t, exists, "user DTO must expose %q", key) + } + + // pricing interval 白名单:不应暴露 id / sort_order。 + pricing := toUserPricing(&service.ChannelModelPricing{ + BillingMode: service.BillingModeToken, + Intervals: []service.PricingInterval{ + {ID: 7, MinTokens: 0, MaxTokens: nil, SortOrder: 3}, + }, + }) + require.NotNil(t, pricing) + require.Len(t, pricing.Intervals, 1) + rawIv, err := json.Marshal(pricing.Intervals[0]) + require.NoError(t, err) + var ivDecoded map[string]any + require.NoError(t, json.Unmarshal(rawIv, &ivDecoded)) + for _, key := range []string{"id", "pricing_id", "sort_order"} { + _, exists := ivDecoded[key] + require.Falsef(t, exists, "user pricing interval must not expose %q", key) + } +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index bedb81ae..a35d8041 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -33,26 +33,28 @@ type AdminHandlers struct { Channel *admin.ChannelHandler ChannelMonitor *admin.ChannelMonitorHandler ChannelMonitorTemplate *admin.ChannelMonitorRequestTemplateHandler + AvailableChannel *admin.AvailableChannelHandler Payment *admin.PaymentHandler } // Handlers contains all HTTP handlers type Handlers struct { - Auth *AuthHandler - User *UserHandler - APIKey *APIKeyHandler - Usage *UsageHandler - Redeem *RedeemHandler - Subscription *SubscriptionHandler - Announcement *AnnouncementHandler - ChannelMonitor *ChannelMonitorUserHandler - Admin *AdminHandlers - Gateway *GatewayHandler - OpenAIGateway *OpenAIGatewayHandler - Setting *SettingHandler - Totp *TotpHandler - Payment *PaymentHandler - PaymentWebhook *PaymentWebhookHandler + Auth *AuthHandler + User *UserHandler + APIKey *APIKeyHandler + Usage *UsageHandler + Redeem *RedeemHandler + Subscription *SubscriptionHandler + Announcement *AnnouncementHandler + ChannelMonitor *ChannelMonitorUserHandler + Admin *AdminHandlers + Gateway *GatewayHandler + OpenAIGateway *OpenAIGatewayHandler + Setting *SettingHandler + Totp *TotpHandler + Payment *PaymentHandler + PaymentWebhook *PaymentWebhookHandler + AvailableChannel *AvailableChannelHandler } // BuildInfo contains build-time information diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 6584eb70..c9296b44 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -36,6 +36,7 @@ func ProvideAdminHandlers( channelHandler *admin.ChannelHandler, channelMonitorHandler *admin.ChannelMonitorHandler, channelMonitorTemplateHandler *admin.ChannelMonitorRequestTemplateHandler, + availableChannelHandler *admin.AvailableChannelHandler, paymentHandler *admin.PaymentHandler, ) *AdminHandlers { return &AdminHandlers{ @@ -66,6 +67,7 @@ func ProvideAdminHandlers( Channel: channelHandler, ChannelMonitor: channelMonitorHandler, ChannelMonitorTemplate: channelMonitorTemplateHandler, + AvailableChannel: availableChannelHandler, Payment: paymentHandler, } } @@ -97,25 +99,27 @@ func ProvideHandlers( totpHandler *TotpHandler, paymentHandler *PaymentHandler, paymentWebhookHandler *PaymentWebhookHandler, + availableChannelHandler *AvailableChannelHandler, _ *service.IdempotencyCoordinator, _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ - Auth: authHandler, - User: userHandler, - APIKey: apiKeyHandler, - Usage: usageHandler, - Redeem: redeemHandler, - Subscription: subscriptionHandler, - Announcement: announcementHandler, - ChannelMonitor: channelMonitorUserHandler, - Admin: adminHandlers, - Gateway: gatewayHandler, - OpenAIGateway: openaiGatewayHandler, - Setting: settingHandler, - Totp: totpHandler, - Payment: paymentHandler, - PaymentWebhook: paymentWebhookHandler, + Auth: authHandler, + User: userHandler, + APIKey: apiKeyHandler, + Usage: usageHandler, + Redeem: redeemHandler, + Subscription: subscriptionHandler, + Announcement: announcementHandler, + ChannelMonitor: channelMonitorUserHandler, + Admin: adminHandlers, + Gateway: gatewayHandler, + OpenAIGateway: openaiGatewayHandler, + Setting: settingHandler, + Totp: totpHandler, + Payment: paymentHandler, + PaymentWebhook: paymentWebhookHandler, + AvailableChannel: availableChannelHandler, } } @@ -136,6 +140,7 @@ var ProviderSet = wire.NewSet( ProvideSettingHandler, NewPaymentHandler, NewPaymentWebhookHandler, + NewAvailableChannelHandler, // Admin handlers admin.NewDashboardHandler, @@ -165,6 +170,7 @@ var ProviderSet = wire.NewSet( admin.NewChannelHandler, admin.NewChannelMonitorHandler, admin.NewChannelMonitorRequestTemplateHandler, + admin.NewAvailableChannelHandler, admin.NewPaymentHandler, // AdminHandlers and Handlers constructors diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 4b796d55..e4b5c548 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -560,6 +560,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { channels := admin.Group("/channels") { channels.GET("", h.Admin.Channel.List) + channels.GET("/available", h.Admin.AvailableChannel.List) channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing) channels.GET("/:id", h.Admin.Channel.GetByID) channels.POST("", h.Admin.Channel.Create) diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index 60503a5b..babab125 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -68,6 +68,12 @@ func RegisterUserRoutes( groups.GET("/rates", h.APIKey.GetUserGroupRates) } + // 用户可用渠道(非管理员接口) + channels := authenticated.Group("/channels") + { + channels.GET("/available", h.AvailableChannel.List) + } + // 使用记录 usage := authenticated.Group("/usage") { diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 93beb972..de31e829 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -345,3 +345,175 @@ type ChannelUsageFields struct { BillingModelSource string // 计费模型来源:"requested" / "upstream" / "channel_mapped" ModelMappingChain string // 映射链描述,如 "a→b→c" } + +// SupportedModel 渠道的一个支持模型条目(无通配符、可直接展示给用户) +type SupportedModel struct { + Name string // 用户侧模型名 + Platform string // 所属平台 + Pricing *ChannelModelPricing // 定价详情(nil 表示未配置定价) +} + +// wildcardSuffix 是模型模式中的通配符后缀标记(仅支持尾部匹配)。 +const wildcardSuffix = "*" + +// splitWildcardSuffix 将模型模式拆分为 (prefix, isWildcard)。 +// +// "claude-opus-*" → ("claude-opus-", true) +// "claude-opus-4" → ("claude-opus-4", false) +// "*" → ("", true) +// +// 注意:返回的 prefix 保持原始大小写,由调用方按需 ToLower。 +func splitWildcardSuffix(pattern string) (prefix string, isWildcard bool) { + if strings.HasSuffix(pattern, wildcardSuffix) { + return strings.TrimSuffix(pattern, wildcardSuffix), true + } + return pattern, false +} + +// GetModelPricingByPlatform 在指定平台下查找精确模型的定价,未找到返回 nil。 +// 与 GetModelPricing 的区别:按 Platform 隔离,避免跨平台同名模型误匹配。 +func (c *Channel) GetModelPricingByPlatform(platform, model string) *ChannelModelPricing { + if c == nil { + return nil + } + modelLower := strings.ToLower(model) + for i := range c.ModelPricing { + if c.ModelPricing[i].Platform != platform { + continue + } + for _, m := range c.ModelPricing[i].Models { + if strings.ToLower(m) == modelLower { + cp := c.ModelPricing[i].Clone() + return &cp + } + } + } + return nil +} + +// pricingLookup 是渠道定价在单个计算过程中的索引:platform → (lowerName → *pricing)。 +// 用于将 SupportedModels 的定价解析从 O(N*M) 降到 O(N+M)。 +type pricingLookup map[string]map[string]*ChannelModelPricing + +// buildPricingLookup 对渠道的定价列表做一次扫描,生成 platform+模型名 的索引。 +// 索引值是定价条目的 Clone 指针,调用方可安全按需返回副本而不污染缓存。 +// wildcard 后缀(如 "claude-*")不会被索引(它们不是精确模型名)。 +func buildPricingLookup(pricings []ChannelModelPricing) pricingLookup { + lookup := make(pricingLookup, len(pricings)) + for i := range pricings { + p := pricings[i] + byModel, ok := lookup[p.Platform] + if !ok { + byModel = make(map[string]*ChannelModelPricing, len(p.Models)) + lookup[p.Platform] = byModel + } + for _, m := range p.Models { + if _, wild := splitWildcardSuffix(m); wild { + continue + } + lower := strings.ToLower(m) + if _, exists := byModel[lower]; exists { + continue // 首个命中胜出(保持 case-insensitive 去重后第一个定价) + } + cp := pricings[i].Clone() + byModel[lower] = &cp + } + } + return lookup +} + +// pricedNamesFor 返回指定平台下已索引的精确模型名(保留原始大小写,按添加顺序)。 +// 它是从 pricingLookup 中取 keys 并回查原始 ModelPricing 以得到原样字符串。 +func pricedNamesFor(pricings []ChannelModelPricing, platform string) []string { + seen := make(map[string]struct{}) + out := make([]string, 0) + for i := range pricings { + if pricings[i].Platform != platform { + continue + } + for _, m := range pricings[i].Models { + if _, wild := splitWildcardSuffix(m); wild { + continue + } + lower := strings.ToLower(m) + if _, ok := seen[lower]; ok { + continue + } + seen[lower] = struct{}{} + out = append(out, m) + } + } + return out +} + +// SupportedModels 计算渠道的支持模型列表,结果保证不含通配符。 +// +// 算法(以渠道自身的 ModelMapping 为唯一入口): +// - 遍历 Channel.ModelMapping 的每个 platform 条目; +// - 映射 key 不带尾部 "*":直接作为一个支持模型名(即使没有匹配的定价行,也会产出 Pricing=nil 的条目); +// - 映射 key 带尾部 "*":用同 platform 的 ModelPricing.Models 做前缀匹配展开(定价中带 "*" 的条目被忽略,因为它们本身就是模式,不是具体模型名); +// - 未在 ModelMapping 中出现的 platform 不会产出任何条目——这是**刻意设计**("没配映射就不显示"),即使该平台有定价行。 +// +// 每个结果尝试从 pricingLookup(平台+模型名索引)查找精确定价,未配置则 Pricing=nil。 +// 结果按 (Platform, Name) 稳定排序,并按 (Platform, lowercase(Name)) 去重。 +func (c *Channel) SupportedModels() []SupportedModel { + if c == nil || len(c.ModelMapping) == 0 { + return nil + } + + lookup := buildPricingLookup(c.ModelPricing) + + type dedupKey struct { + platform string + name string + } + seen := make(map[dedupKey]struct{}) + result := make([]SupportedModel, 0) + + add := func(platform, name string) { + key := dedupKey{platform: platform, name: strings.ToLower(name)} + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + var pricing *ChannelModelPricing + if byModel, ok := lookup[platform]; ok { + if p, ok := byModel[strings.ToLower(name)]; ok { + pricing = p + } + } + result = append(result, SupportedModel{ + Name: name, + Platform: platform, + Pricing: pricing, + }) + } + + for platform, mapping := range c.ModelMapping { + if len(mapping) == 0 { + continue + } + pricedNames := pricedNamesFor(c.ModelPricing, platform) + for src := range mapping { + prefix, isWild := splitWildcardSuffix(src) + if isWild { + prefixLower := strings.ToLower(prefix) + for _, candidate := range pricedNames { + if strings.HasPrefix(strings.ToLower(candidate), prefixLower) { + add(platform, candidate) + } + } + continue + } + add(platform, src) + } + } + + sort.Slice(result, func(i, j int) bool { + if result[i].Platform != result[j].Platform { + return result[i].Platform < result[j].Platform + } + return result[i].Name < result[j].Name + }) + return result +} diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go new file mode 100644 index 00000000..700380c2 --- /dev/null +++ b/backend/internal/service/channel_available.go @@ -0,0 +1,84 @@ +package service + +import ( + "context" + "fmt" + "sort" + "strings" +) + +// AvailableGroupRef 渠道视图中关联分组的简要信息。 +type AvailableGroupRef struct { + ID int64 + Name string + Platform string +} + +// AvailableChannel 可用渠道视图:用于「可用渠道」页面展示渠道基础信息 + +// 关联的分组 + 推导出的支持模型列表(无通配符)。 +type AvailableChannel struct { + ID int64 + Name string + Description string + Status string + BillingModelSource string + RestrictModels bool + Groups []AvailableGroupRef + SupportedModels []SupportedModel +} + +// ListAvailable 返回所有渠道的可用视图:每个渠道附带关联分组信息与支持模型列表。 +// +// 支持模型通过 (*Channel).SupportedModels() 计算得到(见 channel.go)。 +// 关联分组信息通过 groupRepo.ListActive 查询后按 ID 映射;渠道 GroupIDs 中未在活跃列表中 +// 的分组(已停用或删除)会被忽略。 +func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, error) { + channels, err := s.repo.ListAll(ctx) + if err != nil { + return nil, fmt.Errorf("list channels: %w", err) + } + + groupByID := make(map[int64]AvailableGroupRef) + if s.groupRepo != nil { + groups, err := s.groupRepo.ListActive(ctx) + if err != nil { + return nil, fmt.Errorf("list active groups: %w", err) + } + for i := range groups { + g := groups[i] + groupByID[g.ID] = AvailableGroupRef{ + ID: g.ID, + Name: g.Name, + Platform: g.Platform, + } + } + } + + out := make([]AvailableChannel, 0, len(channels)) + for i := range channels { + ch := &channels[i] + groups := make([]AvailableGroupRef, 0, len(ch.GroupIDs)) + for _, gid := range ch.GroupIDs { + if ref, ok := groupByID[gid]; ok { + groups = append(groups, ref) + } + } + sort.Slice(groups, func(i, j int) bool { return groups[i].Name < groups[j].Name }) + + out = append(out, AvailableChannel{ + ID: ch.ID, + Name: ch.Name, + Description: ch.Description, + Status: ch.Status, + BillingModelSource: ch.BillingModelSource, + RestrictModels: ch.RestrictModels, + Groups: groups, + SupportedModels: ch.SupportedModels(), + }) + } + + sort.SliceStable(out, func(i, j int) bool { + return strings.ToLower(out[i].Name) < strings.ToLower(out[j].Name) + }) + return out, nil +} diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go new file mode 100644 index 00000000..6a11fa4b --- /dev/null +++ b/backend/internal/service/channel_available_test.go @@ -0,0 +1,119 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// stubGroupRepoForAvailable 是 ListAvailable 测试用的 GroupRepository stub, +// 仅实现 ListActive;其他方法对本测试无关,返回零值即可。 +type stubGroupRepoForAvailable struct { + activeGroups []Group +} + +func (s *stubGroupRepoForAvailable) ListActive(ctx context.Context) ([]Group, error) { + return s.activeGroups, nil +} + +func (s *stubGroupRepoForAvailable) Create(ctx context.Context, group *Group) error { return nil } +func (s *stubGroupRepoForAvailable) GetByID(ctx context.Context, id int64) (*Group, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) GetByIDLite(ctx context.Context, id int64) (*Group, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) Update(ctx context.Context, group *Group) error { return nil } +func (s *stubGroupRepoForAvailable) Delete(ctx context.Context, id int64) error { return nil } +func (s *stubGroupRepoForAvailable) DeleteCascade(ctx context.Context, id int64) ([]int64, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubGroupRepoForAvailable) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (s *stubGroupRepoForAvailable) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) ExistsByName(ctx context.Context, name string) (bool, error) { + return false, nil +} +func (s *stubGroupRepoForAvailable) GetAccountCount(ctx context.Context, groupID int64) (int64, int64, error) { + return 0, 0, nil +} +func (s *stubGroupRepoForAvailable) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) { + return 0, nil +} +func (s *stubGroupRepoForAvailable) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) { + return nil, nil +} +func (s *stubGroupRepoForAvailable) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error { + return nil +} +func (s *stubGroupRepoForAvailable) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + +// newAvailableChannelService 构造一个 ChannelService,channelRepo.ListAll 返回给定 channels, +// groupRepo 由参数决定(可传 nil 测试 nil 分支)。 +func newAvailableChannelService(channels []Channel, groupRepo GroupRepository) *ChannelService { + repo := &mockChannelRepository{ + listAllFn: func(ctx context.Context) ([]Channel, error) { return channels, nil }, + } + return NewChannelService(repo, groupRepo, nil) +} + +func TestListAvailable_NilGroupRepo_NoGroupsAttached(t *testing.T) { + // groupRepo 为 nil 时不应 panic,且每个渠道的 Groups 应为空切片。 + channels := []Channel{{ + ID: 1, + Name: "chA", + Status: StatusActive, + GroupIDs: []int64{10, 20}, + }} + svc := newAvailableChannelService(channels, nil) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 1) + require.Empty(t, out[0].Groups) +} + +func TestListAvailable_InactiveGroupIDSilentlyDropped(t *testing.T) { + // 渠道 GroupIDs 中引用的 group 未出现在 ListActive 结果中(已停用或删除),应被静默丢弃。 + channels := []Channel{{ + ID: 1, + Name: "chA", + Status: StatusActive, + GroupIDs: []int64{1, 99}, + }} + groupRepo := &stubGroupRepoForAvailable{ + activeGroups: []Group{{ID: 1, Name: "g1", Platform: "anthropic"}}, + } + svc := newAvailableChannelService(channels, groupRepo) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 1) + require.Len(t, out[0].Groups, 1) + require.Equal(t, int64(1), out[0].Groups[0].ID) +} + +func TestListAvailable_SortedByName(t *testing.T) { + channels := []Channel{ + {ID: 1, Name: "beta"}, + {ID: 2, Name: "Alpha"}, + {ID: 3, Name: "charlie"}, + } + svc := newAvailableChannelService(channels, nil) + out, err := svc.ListAvailable(context.Background()) + require.NoError(t, err) + require.Len(t, out, 3) + require.Equal(t, "Alpha", out[0].Name) + require.Equal(t, "beta", out[1].Name) + require.Equal(t, "charlie", out[2].Name) +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index c29550d9..250df07b 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -141,6 +141,7 @@ const ( // ChannelService 渠道管理服务 type ChannelService struct { repo ChannelRepository + groupRepo GroupRepository authCacheInvalidator APIKeyAuthCacheInvalidator cache atomic.Value // *channelCache @@ -148,9 +149,10 @@ type ChannelService struct { } // NewChannelService 创建渠道服务实例 -func NewChannelService(repo ChannelRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { +func NewChannelService(repo ChannelRepository, groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *ChannelService { s := &ChannelService{ repo: repo, + groupRepo: groupRepo, authCacheInvalidator: authCacheInvalidator, } return s @@ -884,12 +886,7 @@ func conflictsBetween(a, b modelEntry) bool { // toModelEntry 将模型名转换为 modelEntry func toModelEntry(pattern string) modelEntry { - lower := strings.ToLower(pattern) - isWild := strings.HasSuffix(lower, "*") - prefix := lower - if isWild { - prefix = strings.TrimSuffix(lower, "*") - } + prefix, isWild := splitWildcardSuffix(strings.ToLower(pattern)) return modelEntry{pattern: pattern, prefix: prefix, wildcard: isWild} } diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go index e1345618..e44b882b 100644 --- a/backend/internal/service/channel_service_test.go +++ b/backend/internal/service/channel_service_test.go @@ -189,11 +189,11 @@ func (m *mockChannelAuthCacheInvalidator) InvalidateAuthCacheByGroupID(_ context // --------------------------------------------------------------------------- func newTestChannelService(repo *mockChannelRepository) *ChannelService { - return NewChannelService(repo, nil) + return NewChannelService(repo, nil, nil) } func newTestChannelServiceWithAuth(repo *mockChannelRepository, auth *mockChannelAuthCacheInvalidator) *ChannelService { - return NewChannelService(repo, auth) + return NewChannelService(repo, nil, auth) } // makeStandardRepo returns a repo that serves one active channel with anthropic pricing diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index deac64d6..812a3a63 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -433,3 +433,207 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) { require.Contains(t, err.Error(), "unbounded") require.Contains(t, err.Error(), "last") } + +func TestSupportedModels_ExactKeysAndPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 10, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)}, + {ID: 11, Platform: "anthropic", Models: []string{"claude-opus-4-6"}, InputPrice: testPtrFloat64(1.5e-5)}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4-6": "claude-sonnet-4-6", + "claude-opus-4-6": "claude-opus-4-6", + }, + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 2) + require.Equal(t, "anthropic", got[0].Platform) + require.Equal(t, "claude-opus-4-6", got[0].Name) + require.NotNil(t, got[0].Pricing) + require.Equal(t, int64(11), got[0].Pricing.ID) + require.Equal(t, "claude-sonnet-4-6", got[1].Name) + require.Equal(t, int64(10), got[1].Pricing.ID) +} + +func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}}, + {ID: 2, Platform: "anthropic", Models: []string{"claude-opus-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-*": "claude-sonnet-4-6", + }, + }, + } + + got := ch.SupportedModels() + names := make([]string, 0, len(got)) + for _, m := range got { + names = append(names, m.Name) + } + require.ElementsMatch(t, []string{"claude-sonnet-4-5", "claude-sonnet-4-6"}, names) + for _, m := range got { + require.NotContains(t, m.Name, "*") + } +} + +func TestSupportedModels_PlatformWithoutMappingSkipped(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"}, + // openai 没有 mapping 条目 + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "anthropic", got[0].Platform) + require.Equal(t, "claude-sonnet-4-6", got[0].Name) +} + +func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) { + ch := &Channel{ + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-4-6": "claude-sonnet-4-6"}, + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "claude-sonnet-4-6", got[0].Name) + require.Nil(t, got[0].Pricing) +} + +func TestSupportedModels_DedupAndSort(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6", "claude-sonnet-4-5"}}, + {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": { + "claude-sonnet-4-6": "upstream-a", + "claude-sonnet-*": "upstream-a", + }, + "openai": {"gpt-4o": "gpt-4o"}, + }, + } + + got := ch.SupportedModels() + require.Len(t, got, 3) + require.Equal(t, "anthropic", got[0].Platform) + require.Equal(t, "claude-sonnet-4-5", got[0].Name) + require.Equal(t, "anthropic", got[1].Platform) + require.Equal(t, "claude-sonnet-4-6", got[1].Name) + require.Equal(t, "openai", got[2].Platform) + require.Equal(t, "gpt-4o", got[2].Name) +} + +func TestSupportedModels_NilChannelAndEmpty(t *testing.T) { + var nilCh *Channel + require.Nil(t, nilCh.SupportedModels()) + + empty := &Channel{} + require.Nil(t, empty.SupportedModels()) +} + +func TestGetModelPricingByPlatform(t *testing.T) { + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(3e-6)}, + {ID: 2, Platform: "openai", Models: []string{"claude-sonnet-4-6"}, InputPrice: testPtrFloat64(1e-6)}, + }, + } + + ant := ch.GetModelPricingByPlatform("anthropic", "claude-sonnet-4-6") + require.NotNil(t, ant) + require.Equal(t, int64(1), ant.ID) + + oa := ch.GetModelPricingByPlatform("openai", "claude-sonnet-4-6") + require.NotNil(t, oa) + require.Equal(t, int64(2), oa.ID) + + require.Nil(t, ch.GetModelPricingByPlatform("gemini", "claude-sonnet-4-6")) +} + +func TestSupportedModels_WildcardOnlyPricingRowsSkipped(t *testing.T) { + // 定价中含通配符条目(pattern),不应被当作具体模型名展开。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-*", "claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-*": "claude-sonnet-4-6"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "claude-sonnet-4-6", got[0].Name) + for _, m := range got { + require.NotContains(t, m.Name, "*") + } +} + +func TestSupportedModels_WildcardPrefixMatchesNothing(t *testing.T) { + // 通配符模式无任何对应定价模型时,该平台应产出 0 个模型。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"gpt-foo-*": "gpt-foo-1"}, + }, + } + require.Empty(t, ch.SupportedModels()) +} + +func TestSupportedModels_CrossPlatformPricingDoesNotBleed(t *testing.T) { + // anthropic 的通配符不应拉入 openai 定价行,哪怕名字恰好前缀匹配。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {"claude-sonnet-*": "x"}, + }, + } + require.Empty(t, ch.SupportedModels()) +} + +func TestSupportedModels_CaseInsensitiveDedup(t *testing.T) { + // 两行定价用不同大小写定义了同一模型,结果应去重为 1 条;首次出现的原始大小写保留。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "openai", Models: []string{"GPT-4o"}}, + {ID: 2, Platform: "openai", Models: []string{"gpt-4o"}}, + }, + ModelMapping: map[string]map[string]string{ + "openai": {"gpt-*": "x"}, + }, + } + got := ch.SupportedModels() + require.Len(t, got, 1) + require.Equal(t, "GPT-4o", got[0].Name) +} + +func TestSupportedModels_EmptyPlatformMapping(t *testing.T) { + // ModelMapping 有一个 platform key 但 value 是空 map —— 该 platform 应被跳过。 + ch := &Channel{ + ModelPricing: []ChannelModelPricing{ + {ID: 1, Platform: "anthropic", Models: []string{"claude-sonnet-4-6"}}, + }, + ModelMapping: map[string]map[string]string{ + "anthropic": {}, + }, + } + require.Empty(t, ch.SupportedModels()) +} diff --git a/backend/internal/service/model_pricing_resolver_test.go b/backend/internal/service/model_pricing_resolver_test.go index 905c4df6..7484eed5 100644 --- a/backend/internal/service/model_pricing_resolver_test.go +++ b/backend/internal/service/model_pricing_resolver_test.go @@ -184,7 +184,7 @@ func newResolverWithChannel(t *testing.T, pricing []ChannelModelPricing) *ModelP return map[int64]string{groupID: "anthropic"}, nil }, } - cs := NewChannelService(repo, nil) + cs := NewChannelService(repo, nil, nil) bs := newTestBillingServiceForResolver() return NewModelPricingResolver(cs, bs) } @@ -517,7 +517,7 @@ func TestResolve_WithChannelOverride_CacheError(t *testing.T) { return nil, errors.New("database unavailable") }, } - cs := NewChannelService(repo, nil) + cs := NewChannelService(repo, nil, nil) bs := newTestBillingServiceForResolver() r := NewModelPricingResolver(cs, bs) diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index f129ceaa..eb7e91d8 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -163,5 +163,42 @@ export async function getModelDefaultPricing(model: string): Promise { + const { data } = await apiClient.get('/admin/channels/available', { + signal: options?.signal + }) + return data.items +} + +const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, listAvailable } export default channelsAPI diff --git a/frontend/src/api/channels.ts b/frontend/src/api/channels.ts new file mode 100644 index 00000000..98b890df --- /dev/null +++ b/frontend/src/api/channels.ts @@ -0,0 +1,60 @@ +/** + * User Channels API endpoints (non-admin) + * 用户侧「可用渠道」聚合查询:渠道 + 用户可访问的分组 + 支持模型(含定价)。 + */ + +import { apiClient } from './client' +import type { BillingMode } from '@/constants/channel' + +export interface UserAvailableGroup { + id: number + name: string + platform: string +} + +export interface UserPricingInterval { + min_tokens: number + max_tokens: number | null + tier_label?: string + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + per_request_price: number | null +} + +export interface UserSupportedModelPricing { + billing_mode: BillingMode + input_price: number | null + output_price: number | null + cache_write_price: number | null + cache_read_price: number | null + image_output_price: number | null + per_request_price: number | null + intervals: UserPricingInterval[] +} + +export interface UserSupportedModel { + name: string + platform: string + pricing: UserSupportedModelPricing | null +} + +export interface UserAvailableChannel { + name: string + description: string + groups: UserAvailableGroup[] + supported_models: UserSupportedModel[] +} + +/** 列出当前用户可见的「可用渠道」(与 /groups/available 保持一致,返回平数组)。 */ +export async function getAvailable(options?: { signal?: AbortSignal }): Promise { + const { data } = await apiClient.get('/channels/available', { + signal: options?.signal + }) + return data +} + +export const userChannelsAPI = { getAvailable } + +export default userChannelsAPI diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index dd005a0d..6702468d 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -16,6 +16,7 @@ export { userAPI } from './user' export { redeemAPI, type RedeemHistoryItem } from './redeem' export { paymentAPI } from './payment' export { userGroupsAPI } from './groups' +export { userChannelsAPI } from './channels' export { totpAPI } from './totp' export { default as announcementsAPI } from './announcements' export { channelMonitorUserAPI } from './channelMonitor' diff --git a/frontend/src/components/channels/AvailableChannelsTable.vue b/frontend/src/components/channels/AvailableChannelsTable.vue new file mode 100644 index 00000000..403391a3 --- /dev/null +++ b/frontend/src/components/channels/AvailableChannelsTable.vue @@ -0,0 +1,110 @@ + + + diff --git a/frontend/src/components/channels/PricingRow.vue b/frontend/src/components/channels/PricingRow.vue new file mode 100644 index 00000000..8db077c0 --- /dev/null +++ b/frontend/src/components/channels/PricingRow.vue @@ -0,0 +1,29 @@ + + + diff --git a/frontend/src/components/channels/SupportedModelChip.vue b/frontend/src/components/channels/SupportedModelChip.vue new file mode 100644 index 00000000..82f27607 --- /dev/null +++ b/frontend/src/components/channels/SupportedModelChip.vue @@ -0,0 +1,214 @@ + + + diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index 248e0021..25284276 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -648,6 +648,7 @@ function buildSelfNavItems(withDashboard: boolean): NavItem[] { { path: '/subscriptions', label: t('nav.mySubscriptions'), icon: CreditCardIcon, hideInSimpleMode: true }, { path: '/purchase', label: t('nav.buySubscription'), icon: RechargeSubscriptionIcon, hideInSimpleMode: true, featureFlag: flagPayment }, { path: '/orders', label: t('nav.myOrders'), icon: OrderListIcon, hideInSimpleMode: true, featureFlag: flagPayment }, + { path: '/available-channels', label: t('nav.availableChannels'), icon: ChannelIcon, hideInSimpleMode: true }, { path: '/redeem', label: t('nav.redeem'), icon: GiftIcon, hideInSimpleMode: true }, { path: '/profile', label: t('nav.profile'), icon: UserIcon }, ...customMenuItemsForUser.value.map((item): NavItem => ({ diff --git a/frontend/src/constants/channel.ts b/frontend/src/constants/channel.ts new file mode 100644 index 00000000..c08f4800 --- /dev/null +++ b/frontend/src/constants/channel.ts @@ -0,0 +1,22 @@ +/** Channel status values (must match service.Status* constants in Go). */ +export const CHANNEL_STATUS_ACTIVE = 'active' as const +export const CHANNEL_STATUS_DISABLED = 'disabled' as const +export type ChannelStatus = typeof CHANNEL_STATUS_ACTIVE | typeof CHANNEL_STATUS_DISABLED + +/** Billing mode values (must match service.BillingMode* constants in Go). */ +export const BILLING_MODE_TOKEN = 'token' as const +export const BILLING_MODE_PER_REQUEST = 'per_request' as const +export const BILLING_MODE_IMAGE = 'image' as const +export type BillingMode = + | typeof BILLING_MODE_TOKEN + | typeof BILLING_MODE_PER_REQUEST + | typeof BILLING_MODE_IMAGE + +/** Billing-model-source values (must match service.BillingModelSource* constants in Go). */ +export const BILLING_MODEL_SOURCE_REQUESTED = 'requested' as const +export const BILLING_MODEL_SOURCE_UPSTREAM = 'upstream' as const +export const BILLING_MODEL_SOURCE_CHANNEL_MAPPED = 'channel_mapped' as const +export type BillingModelSource = + | typeof BILLING_MODEL_SOURCE_REQUESTED + | typeof BILLING_MODEL_SOURCE_UPSTREAM + | typeof BILLING_MODEL_SOURCE_CHANNEL_MAPPED diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index eb401ae2..a54639cc 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -344,6 +344,7 @@ export default { users: 'Users', groups: 'Groups', channels: 'Channels', + availableChannels: 'Available Channels', subscriptions: 'Subscriptions', accounts: 'Accounts', proxies: 'Proxies', @@ -929,6 +930,38 @@ export default { } }, + // Available Channels (user-facing) + availableChannels: { + title: 'Available Channels', + description: 'Channels you can access, along with their supported models and pricing', + searchPlaceholder: 'Search channels or models...', + empty: 'No available channels', + noModels: 'No models configured', + noPricing: 'Pricing not configured', + columns: { + name: 'Channel', + groups: 'Your Accessible Groups', + supportedModels: 'Supported Models' + }, + pricing: { + billingMode: 'Billing Mode', + billingModeToken: 'Per Token', + billingModePerRequest: 'Per Request', + billingModeImage: 'Per Image', + inputPrice: 'Input', + outputPrice: 'Output', + cacheWritePrice: 'Cache Write', + cacheReadPrice: 'Cache Read', + imageOutputPrice: 'Image Output', + perRequestPrice: 'Per Request', + intervals: 'Tiered Pricing', + tierLabel: 'Tier', + tokenRange: 'Token Range', + unitPerMillion: '/ 1M tokens', + unitPerRequest: '/ request' + } + }, + // Redeem redeem: { title: 'Redeem Code', @@ -1980,6 +2013,48 @@ export default { } }, + // Available Channels (aggregated read-only view) + availableChannels: { + title: 'Available Channels', + description: 'Aggregated view: each channel with its linked groups and supported models (wildcards expanded)', + searchPlaceholder: 'Search channels or models...', + columns: { + name: 'Channel', + status: 'Status', + billingSource: 'Billing Model Source', + groups: 'Linked Groups', + supportedModels: 'Supported Models' + }, + empty: 'No data', + noGroups: 'No linked groups', + noModels: 'No model mapping configured', + noPricing: 'Pricing not configured', + statusActive: 'Active', + statusDisabled: 'Disabled', + billingSource: { + requested: 'Requested model', + upstream: 'Upstream model', + channel_mapped: 'Channel-mapped model' + }, + pricing: { + billingMode: 'Billing Mode', + billingModeToken: 'Per Token', + billingModePerRequest: 'Per Request', + billingModeImage: 'Per Image', + inputPrice: 'Input', + outputPrice: 'Output', + cacheWritePrice: 'Cache Write', + cacheReadPrice: 'Cache Read', + imageOutputPrice: 'Image Output', + perRequestPrice: 'Per Request', + intervals: 'Tiered Pricing', + tierLabel: 'Tier', + tokenRange: 'Token Range', + unitPerMillion: '/ 1M tokens', + unitPerRequest: '/ request' + } + }, + // Channel Management channels: { title: 'Channel Management', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index d38b5034..e69b0223 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -344,6 +344,7 @@ export default { users: '用户管理', groups: '分组管理', channels: '渠道管理', + availableChannels: '可用渠道', subscriptions: '订阅管理', accounts: '账号管理', proxies: 'IP管理', @@ -933,6 +934,38 @@ export default { } }, + // Available Channels (user-facing) + availableChannels: { + title: '可用渠道', + description: '查看您可访问的渠道与其支持的模型、定价', + searchPlaceholder: '搜索渠道或模型...', + empty: '暂无可用渠道', + noModels: '未配置模型', + noPricing: '未配置定价', + columns: { + name: '渠道名', + groups: '我可访问的分组', + supportedModels: '支持模型' + }, + pricing: { + billingMode: '计费模式', + billingModeToken: '按 Token', + billingModePerRequest: '按次', + billingModeImage: '按图片', + inputPrice: '输入', + outputPrice: '输出', + cacheWritePrice: '缓存写入', + cacheReadPrice: '缓存读取', + imageOutputPrice: '图片输出', + perRequestPrice: '每次请求', + intervals: '阶梯定价', + tierLabel: '层级', + tokenRange: 'Token 区间', + unitPerMillion: '/ 1M token', + unitPerRequest: '/ 次' + } + }, + // Redeem redeem: { title: '兑换码', @@ -2059,6 +2092,48 @@ export default { } }, + // Available Channels (aggregated read-only view) + availableChannels: { + title: '可用渠道', + description: '按渠道聚合查看关联分组与支持模型(已展开通配符)', + searchPlaceholder: '搜索渠道或模型...', + columns: { + name: '渠道名', + status: '状态', + billingSource: '计费模型来源', + groups: '关联分组', + supportedModels: '支持模型' + }, + empty: '暂无数据', + noGroups: '未关联分组', + noModels: '未配置模型映射', + noPricing: '未配置定价', + statusActive: '启用', + statusDisabled: '停用', + billingSource: { + requested: '请求模型', + upstream: '上游模型', + channel_mapped: '映射后模型' + }, + pricing: { + billingMode: '计费模式', + billingModeToken: '按 Token', + billingModePerRequest: '按次', + billingModeImage: '按图片', + inputPrice: '输入', + outputPrice: '输出', + cacheWritePrice: '缓存写入', + cacheReadPrice: '缓存读取', + imageOutputPrice: '图片输出', + perRequestPrice: '每次请求', + intervals: '阶梯定价', + tierLabel: '层级', + tokenRange: 'Token 区间', + unitPerMillion: '/ 1M token', + unitPerRequest: '/ 次' + } + }, + // Channel Management channels: { title: '渠道管理', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 491a984d..567876b6 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -197,6 +197,18 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'redeem.description' } }, + { + path: '/available-channels', + name: 'UserAvailableChannels', + component: () => import('@/views/user/AvailableChannelsView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: false, + title: 'Available Channels', + titleKey: 'availableChannels.title', + descriptionKey: 'availableChannels.description' + } + }, { path: '/profile', name: 'Profile', @@ -358,6 +370,18 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.groups.description' } }, + { + path: '/admin/available-channels', + name: 'AdminAvailableChannels', + component: () => import('@/views/admin/AvailableChannelsView.vue'), + meta: { + requiresAuth: true, + requiresAdmin: true, + title: 'Available Channels', + titleKey: 'admin.availableChannels.title', + descriptionKey: 'admin.availableChannels.description' + } + }, { path: '/admin/channels', redirect: '/admin/channels/pricing' diff --git a/frontend/src/views/admin/AvailableChannelsView.vue b/frontend/src/views/admin/AvailableChannelsView.vue new file mode 100644 index 00000000..3f0ee436 --- /dev/null +++ b/frontend/src/views/admin/AvailableChannelsView.vue @@ -0,0 +1,135 @@ + + + diff --git a/frontend/src/views/user/AvailableChannelsView.vue b/frontend/src/views/user/AvailableChannelsView.vue new file mode 100644 index 00000000..44ee456e --- /dev/null +++ b/frontend/src/views/user/AvailableChannelsView.vue @@ -0,0 +1,98 @@ + + +