fix: 图片计费代码审查问题修复
- isImageGenerationModel 改为精确匹配/前缀匹配,避免误匹配 - 新增 normalizePrice 函数,支持负数清除价格配置 - 更新注释说明 Gemini API 每次请求只生成一张图片 - 添加测试用例验证不会误匹配自定义模型名
This commit is contained in:
@@ -33,7 +33,7 @@ type CreateGroupRequest struct {
|
|||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
@@ -51,7 +51,7 @@ type UpdateGroupRequest struct {
|
|||||||
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
DailyLimitUSD *float64 `json:"daily_limit_usd"`
|
||||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
|
||||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
|
||||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
// 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置)
|
||||||
ImagePrice1K *float64 `json:"image_price_1k"`
|
ImagePrice1K *float64 `json:"image_price_1k"`
|
||||||
ImagePrice2K *float64 `json:"image_price_2k"`
|
ImagePrice2K *float64 `json:"image_price_2k"`
|
||||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||||
|
|||||||
@@ -504,6 +504,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||||
|
|
||||||
|
// 图片价格:负数表示清除(使用默认价格),0 保留(表示免费)
|
||||||
|
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||||
|
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||||
|
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||||
|
|
||||||
group := &Group{
|
group := &Group{
|
||||||
Name: input.Name,
|
Name: input.Name,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
@@ -515,9 +520,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
|||||||
DailyLimitUSD: dailyLimit,
|
DailyLimitUSD: dailyLimit,
|
||||||
WeeklyLimitUSD: weeklyLimit,
|
WeeklyLimitUSD: weeklyLimit,
|
||||||
MonthlyLimitUSD: monthlyLimit,
|
MonthlyLimitUSD: monthlyLimit,
|
||||||
ImagePrice1K: input.ImagePrice1K,
|
ImagePrice1K: imagePrice1K,
|
||||||
ImagePrice2K: input.ImagePrice2K,
|
ImagePrice2K: imagePrice2K,
|
||||||
ImagePrice4K: input.ImagePrice4K,
|
ImagePrice4K: imagePrice4K,
|
||||||
}
|
}
|
||||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -533,6 +538,14 @@ func normalizeLimit(limit *float64) *float64 {
|
|||||||
return limit
|
return limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalizePrice 将负数转换为 nil(表示使用默认价格),0 保留(表示免费)
|
||||||
|
func normalizePrice(price *float64) *float64 {
|
||||||
|
if price == nil || *price < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return price
|
||||||
|
}
|
||||||
|
|
||||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||||
group, err := s.groupRepo.GetByID(ctx, id)
|
group, err := s.groupRepo.GetByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -572,15 +585,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
|||||||
if input.MonthlyLimitUSD != nil {
|
if input.MonthlyLimitUSD != nil {
|
||||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||||
}
|
}
|
||||||
// 图片生成计费配置
|
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||||
if input.ImagePrice1K != nil {
|
if input.ImagePrice1K != nil {
|
||||||
group.ImagePrice1K = input.ImagePrice1K
|
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||||
}
|
}
|
||||||
if input.ImagePrice2K != nil {
|
if input.ImagePrice2K != nil {
|
||||||
group.ImagePrice2K = input.ImagePrice2K
|
group.ImagePrice2K = normalizePrice(input.ImagePrice2K)
|
||||||
}
|
}
|
||||||
if input.ImagePrice4K != nil {
|
if input.ImagePrice4K != nil {
|
||||||
group.ImagePrice4K = input.ImagePrice4K
|
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||||
|
|||||||
@@ -838,7 +838,7 @@ handleSuccess:
|
|||||||
// 判断是否为图片生成模型
|
// 判断是否为图片生成模型
|
||||||
imageCount := 0
|
imageCount := 0
|
||||||
if isImageGenerationModel(mappedModel) {
|
if isImageGenerationModel(mappedModel) {
|
||||||
// 图片模型按次计费,默认 1 张图片
|
// Gemini 图片生成 API 每次请求只生成一张图片(API 限制)
|
||||||
imageCount = 1
|
imageCount = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1192,8 +1192,17 @@ func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isImageGenerationModel 判断模型是否为图片生成模型
|
// isImageGenerationModel 判断模型是否为图片生成模型
|
||||||
|
// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
|
||||||
func isImageGenerationModel(model string) bool {
|
func isImageGenerationModel(model string) bool {
|
||||||
modelLower := strings.ToLower(model)
|
modelLower := strings.ToLower(model)
|
||||||
return strings.Contains(modelLower, "gemini-3-pro-image") ||
|
// 移除 models/ 前缀
|
||||||
strings.Contains(modelLower, "gemini-2.5-flash-image")
|
modelLower = strings.TrimPrefix(modelLower, "models/")
|
||||||
|
|
||||||
|
// 精确匹配或前缀匹配
|
||||||
|
return modelLower == "gemini-3-pro-image" ||
|
||||||
|
modelLower == "gemini-3-pro-image-preview" ||
|
||||||
|
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
|
||||||
|
modelLower == "gemini-2.5-flash-image" ||
|
||||||
|
modelLower == "gemini-2.5-flash-image-preview" ||
|
||||||
|
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,9 @@ func TestIsImageGenerationModel_RegularModel(t *testing.T) {
|
|||||||
require.False(t, isImageGenerationModel("gpt-4o"))
|
require.False(t, isImageGenerationModel("gpt-4o"))
|
||||||
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
|
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
|
||||||
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
|
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
|
||||||
|
// 验证不会误匹配包含关键词的自定义模型名
|
||||||
|
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
|
||||||
|
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
|
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
|
||||||
|
|||||||
Reference in New Issue
Block a user