fix: 图片计费代码审查问题修复

- isImageGenerationModel 改为精确匹配/前缀匹配,避免误匹配
- 新增 normalizePrice 函数,支持负数清除价格配置
- 更新注释说明 Gemini API 每次请求只生成一张图片
- 添加测试用例验证不会误匹配自定义模型名
This commit is contained in:
song
2026-01-05 17:14:06 +08:00
parent d4c2b723a5
commit 5b1907fe61
4 changed files with 37 additions and 12 deletions

View File

@@ -33,7 +33,7 @@ type CreateGroupRequest struct {
DailyLimitUSD *float64 `json:"daily_limit_usd"` DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 平台使用 // 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`
@@ -51,7 +51,7 @@ type UpdateGroupRequest struct {
DailyLimitUSD *float64 `json:"daily_limit_usd"` DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
// 图片生成计费配置(antigravity 平台使用 // 图片生成计费配置antigravity 和 gemini 平台使用,负数表示清除配置
ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"` ImagePrice4K *float64 `json:"image_price_4k"`

View File

@@ -504,6 +504,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
// 图片价格负数表示清除使用默认价格0 保留(表示免费)
imagePrice1K := normalizePrice(input.ImagePrice1K)
imagePrice2K := normalizePrice(input.ImagePrice2K)
imagePrice4K := normalizePrice(input.ImagePrice4K)
group := &Group{ group := &Group{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
@@ -515,9 +520,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
DailyLimitUSD: dailyLimit, DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit, WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit, MonthlyLimitUSD: monthlyLimit,
ImagePrice1K: input.ImagePrice1K, ImagePrice1K: imagePrice1K,
ImagePrice2K: input.ImagePrice2K, ImagePrice2K: imagePrice2K,
ImagePrice4K: input.ImagePrice4K, ImagePrice4K: imagePrice4K,
} }
if err := s.groupRepo.Create(ctx, group); err != nil { if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err return nil, err
@@ -533,6 +538,14 @@ func normalizeLimit(limit *float64) *float64 {
return limit return limit
} }
// normalizePrice 将负数转换为 nil表示使用默认价格0 保留(表示免费)
func normalizePrice(price *float64) *float64 {
if price == nil || *price < 0 {
return nil
}
return price
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id) group, err := s.groupRepo.GetByID(ctx, id)
if err != nil { if err != nil {
@@ -572,15 +585,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.MonthlyLimitUSD != nil { if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
} }
// 图片生成计费配置 // 图片生成计费配置:负数表示清除(使用默认价格)
if input.ImagePrice1K != nil { if input.ImagePrice1K != nil {
group.ImagePrice1K = input.ImagePrice1K group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
} }
if input.ImagePrice2K != nil { if input.ImagePrice2K != nil {
group.ImagePrice2K = input.ImagePrice2K group.ImagePrice2K = normalizePrice(input.ImagePrice2K)
} }
if input.ImagePrice4K != nil { if input.ImagePrice4K != nil {
group.ImagePrice4K = input.ImagePrice4K group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
} }
if err := s.groupRepo.Update(ctx, group); err != nil { if err := s.groupRepo.Update(ctx, group); err != nil {

View File

@@ -838,7 +838,7 @@ handleSuccess:
// 判断是否为图片生成模型 // 判断是否为图片生成模型
imageCount := 0 imageCount := 0
if isImageGenerationModel(mappedModel) { if isImageGenerationModel(mappedModel) {
// 图片模型按次计费,默认 1 张图片 // Gemini 图片生成 API 每次请求只生成一张图片API 限制)
imageCount = 1 imageCount = 1
} }
@@ -1192,8 +1192,17 @@ func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
} }
// isImageGenerationModel 判断模型是否为图片生成模型 // isImageGenerationModel 判断模型是否为图片生成模型
// 支持的模型gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
func isImageGenerationModel(model string) bool { func isImageGenerationModel(model string) bool {
modelLower := strings.ToLower(model) modelLower := strings.ToLower(model)
return strings.Contains(modelLower, "gemini-3-pro-image") || // 移除 models/ 前缀
strings.Contains(modelLower, "gemini-2.5-flash-image") modelLower = strings.TrimPrefix(modelLower, "models/")
// 精确匹配或前缀匹配
return modelLower == "gemini-3-pro-image" ||
modelLower == "gemini-3-pro-image-preview" ||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
modelLower == "gemini-2.5-flash-image" ||
modelLower == "gemini-2.5-flash-image-preview" ||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
} }

View File

@@ -28,6 +28,9 @@ func TestIsImageGenerationModel_RegularModel(t *testing.T) {
require.False(t, isImageGenerationModel("gpt-4o")) require.False(t, isImageGenerationModel("gpt-4o"))
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型 require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
require.False(t, isImageGenerationModel("gemini-2.5-flash")) require.False(t, isImageGenerationModel("gemini-2.5-flash"))
// 验证不会误匹配包含关键词的自定义模型名
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
} }
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感 // TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感