diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index cde870de..85590c12 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -217,7 +217,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) - channelHandler := admin.NewChannelHandler(channelService) + channelHandler := admin.NewChannelHandler(channelService, billingService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 30cd0645..77540d3d 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -14,11 +14,12 @@ import ( // ChannelHandler handles admin channel management type ChannelHandler struct { channelService *service.ChannelService + billingService *service.BillingService } // NewChannelHandler creates a new admin channel handler -func NewChannelHandler(channelService *service.ChannelService) *ChannelHandler { - return &ChannelHandler{channelService: channelService} +func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler { + return &ChannelHandler{channelService: channelService, billingService: billingService} } // --- Request / Response types --- @@ -346,3 +347,28 @@ func (h *ChannelHandler) Delete(c *gin.Context) { response.Success(c, gin.H{"message": "Channel deleted successfully"}) } + +// GetModelDefaultPricing 获取模型的默认定价(用于前端自动填充) +// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4 +func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) { + model := strings.TrimSpace(c.Query("model")) + if model == "" { + response.BadRequest(c, "model parameter is required") + return + } + + pricing, err := h.billingService.GetModelPricing(model) + if err != nil { + // 模型不在定价列表中 + response.Success(c, gin.H{"found": false}) + return + } + + response.Success(c, gin.H{ + "found": true, + "input_price": pricing.InputPricePerToken, + "output_price": pricing.OutputPricePerToken, + "cache_write_price": pricing.CacheCreationPricePerToken, + "cache_read_price": pricing.CacheReadPricePerToken, + }) +} diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index abc28295..76f4c4b4 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -575,6 +575,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) { channels := admin.Group("/channels") { channels.GET("", h.Admin.Channel.List) + channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing) channels.GET("/:id", h.Admin.Channel.GetByID) channels.POST("", h.Admin.Channel.Create) channels.PUT("/:id", h.Admin.Channel.Update) diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index df23db93..79d824c5 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -128,5 +128,20 @@ export async function remove(id: number): Promise { await apiClient.delete(`/admin/channels/${id}`) } -const channelsAPI = { list, getById, create, update, remove } +export interface ModelDefaultPricing { + found: boolean + input_price?: number // per-token price + output_price?: number + cache_write_price?: number + cache_read_price?: number +} + +export async function getModelDefaultPricing(model: string): Promise { + const { data } = await apiClient.get('/admin/channels/model-pricing', { + params: { model } + }) + return data +} + +const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing } export default channelsAPI diff --git a/frontend/src/components/admin/channel/PricingEntryCard.vue b/frontend/src/components/admin/channel/PricingEntryCard.vue index 6e97676a..633fbdee 100644 --- a/frontend/src/components/admin/channel/PricingEntryCard.vue +++ b/frontend/src/components/admin/channel/PricingEntryCard.vue @@ -74,7 +74,7 @@ @@ -232,7 +232,9 @@ import Icon from '@/components/icons/Icon.vue' import IntervalRow from './IntervalRow.vue' import ModelTagInput from './ModelTagInput.vue' import type { PricingFormEntry, IntervalFormEntry } from './types' +import { perTokenToMTok } from './types' import type { BillingMode } from '@/api/admin/channels' +import channelsAPI from '@/api/admin/channels' const { t } = useI18n() @@ -297,6 +299,38 @@ function removeInterval(idx: number) { intervals.splice(idx, 1) emit('update', { ...props.entry, intervals }) } + +async function onModelsUpdate(newModels: string[]) { + const oldModels = props.entry.models + emit('update', { ...props.entry, models: newModels }) + + // 只在新增模型且当前无价格时自动填充 + const addedModels = newModels.filter(m => !oldModels.includes(m)) + if (addedModels.length === 0) return + + // 检查是否所有价格字段都为空 + const e = props.entry + const hasPrice = e.input_price != null || e.output_price != null || + e.cache_write_price != null || e.cache_read_price != null + if (hasPrice) return + + // 查询第一个新增模型的默认价格 + try { + const result = await channelsAPI.getModelDefaultPricing(addedModels[0]) + if (result.found) { + emit('update', { + ...props.entry, + models: newModels, + input_price: perTokenToMTok(result.input_price ?? null), + output_price: perTokenToMTok(result.output_price ?? null), + cache_write_price: perTokenToMTok(result.cache_write_price ?? null), + cache_read_price: perTokenToMTok(result.cache_read_price ?? null), + }) + } + } catch { + // 查询失败不影响用户操作 + } +}