merge: 合并远程分支并修复代码冲突
合并了远程分支 cb72262 的功能更新,同时保留了 ESLint 修复:
**冲突解决详情:**
1. AccountTableFilters.vue
- ✅ 保留 emit 模式修复(避免 vue/no-mutating-props 错误)
- ✅ 添加第三个筛选器 type(账户类型)
- ✅ 新增 antigravity 平台和 inactive 状态选项
2. UserBalanceModal.vue
- ✅ 保留 console.error 错误日志
- ✅ 添加输入验证(金额校验、余额不足检查)
- ✅ 使用 appStore.showError 向用户显示友好错误
3. AccountsView.vue
- ✅ 保留所有 console.error 错误日志(避免 no-empty 错误)
- ✅ 使用新 API:clearRateLimit 和 setSchedulable
4. UsageView.vue
- ✅ 添加 console.error 错误日志
- ✅ 添加图表功能(模型分布、使用趋势)
- ✅ 添加粒度选择(按天/按小时)
- ✅ 保留 XLSX 动态导入优化
**测试结果:**
- ✅ Go tests: PASS
- ✅ golangci-lint: 0 issues
- ✅ ESLint: 0 errors
- ✅ TypeScript: PASS
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -98,6 +98,10 @@ type CreateGroupInput struct {
|
||||
DailyLimitUSD *float64 // 日限额 (USD)
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -111,6 +115,10 @@ type UpdateGroupInput struct {
|
||||
DailyLimitUSD *float64 // 日限额 (USD)
|
||||
WeeklyLimitUSD *float64 // 周限额 (USD)
|
||||
MonthlyLimitUSD *float64 // 月限额 (USD)
|
||||
// 图片生成计费配置(仅 antigravity 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -498,6 +506,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
|
||||
// 图片价格:负数表示清除(使用默认价格),0 保留(表示免费)
|
||||
imagePrice1K := normalizePrice(input.ImagePrice1K)
|
||||
imagePrice2K := normalizePrice(input.ImagePrice2K)
|
||||
imagePrice4K := normalizePrice(input.ImagePrice4K)
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -509,6 +522,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
ImagePrice1K: imagePrice1K,
|
||||
ImagePrice2K: imagePrice2K,
|
||||
ImagePrice4K: imagePrice4K,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -524,6 +540,14 @@ func normalizeLimit(limit *float64) *float64 {
|
||||
return limit
|
||||
}
|
||||
|
||||
// normalizePrice 将负数转换为 nil(表示使用默认价格),0 保留(表示免费)
|
||||
func normalizePrice(price *float64) *float64 {
|
||||
if price == nil || *price < 0 {
|
||||
return nil
|
||||
}
|
||||
return price
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -563,6 +587,16 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
// 图片生成计费配置:负数表示清除(使用默认价格)
|
||||
if input.ImagePrice1K != nil {
|
||||
group.ImagePrice1K = normalizePrice(input.ImagePrice1K)
|
||||
}
|
||||
if input.ImagePrice2K != nil {
|
||||
group.ImagePrice2K = normalizePrice(input.ImagePrice2K)
|
||||
}
|
||||
if input.ImagePrice4K != nil {
|
||||
group.ImagePrice4K = normalizePrice(input.ImagePrice4K)
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -702,7 +736,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
account.Extra = input.Extra
|
||||
}
|
||||
if input.ProxyID != nil {
|
||||
account.ProxyID = input.ProxyID
|
||||
// 0 表示清除代理(前端发送 0 而不是 null 来表达清除意图)
|
||||
if *input.ProxyID == 0 {
|
||||
account.ProxyID = nil
|
||||
} else {
|
||||
account.ProxyID = input.ProxyID
|
||||
}
|
||||
account.Proxy = nil // 清除关联对象,防止 GORM Save 时根据 Proxy.ID 覆盖 ProxyID
|
||||
}
|
||||
// 只在指针非 nil 时更新 Concurrency(支持设置为 0)
|
||||
|
||||
197
backend/internal/service/admin_service_group_test.go
Normal file
197
backend/internal/service/admin_service_group_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
|
||||
type groupRepoStubForAdmin struct {
|
||||
created *Group // 记录 Create 调用的参数
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了正确的字段
|
||||
require.NotNil(t, repo.created)
|
||||
require.NotNil(t, repo.created.ImagePrice1K)
|
||||
require.NotNil(t, repo.created.ImagePrice2K)
|
||||
require.NotNil(t, repo.created.ImagePrice4K)
|
||||
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
|
||||
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
// ImagePrice 字段全部为 nil
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 ImagePrice 字段为 nil
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.ImagePrice1K)
|
||||
require.Nil(t, repo.created.ImagePrice2K)
|
||||
require.Nil(t, repo.created.ImagePrice4K)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
|
||||
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.12
|
||||
price2K := 0.18
|
||||
price4K := 0.36
|
||||
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了更新后的字段
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.NotNil(t, repo.updated.ImagePrice4K)
|
||||
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
|
||||
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
oldPrice2K := 0.15
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
// 只更新 1K 价格
|
||||
price1K := 0.10
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
// ImagePrice2K 和 ImagePrice4K 为 nil,不更新
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证:1K 被更新,2K 保持原值,4K 仍为 nil
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -405,6 +406,14 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -414,7 +423,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
@@ -427,7 +439,10 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
@@ -845,6 +860,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||
}
|
||||
|
||||
// 解析请求以获取 image_size(用于图片计费)
|
||||
imageSize := s.extractImageSize(body)
|
||||
|
||||
switch action {
|
||||
case "generateContent", "streamGenerateContent":
|
||||
// ok
|
||||
@@ -901,6 +919,14 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
// 检查 context 是否已取消(客户端断开连接)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -910,7 +936,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
@@ -923,7 +952,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
||||
log.Printf("%s status=context_canceled_during_backoff", prefix)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
@@ -1030,6 +1062,13 @@ handleSuccess:
|
||||
usage = &ClaudeUsage{}
|
||||
}
|
||||
|
||||
// 判断是否为图片生成模型
|
||||
imageCount := 0
|
||||
if isImageGenerationModel(mappedModel) {
|
||||
// Gemini 图片生成 API 每次请求只生成一张图片(API 限制)
|
||||
imageCount = 1
|
||||
}
|
||||
|
||||
return &ForwardResult{
|
||||
RequestID: requestID,
|
||||
Usage: *usage,
|
||||
@@ -1037,6 +1076,8 @@ handleSuccess:
|
||||
Stream: stream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
ImageCount: imageCount,
|
||||
ImageSize: imageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1058,8 +1099,28 @@ func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int)
|
||||
}
|
||||
}
|
||||
|
||||
func sleepAntigravityBackoff(attempt int) {
|
||||
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
|
||||
// sleepAntigravityBackoffWithContext 带 context 取消检查的退避等待
|
||||
// 返回 true 表示正常完成等待,false 表示 context 已取消
|
||||
func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
|
||||
delay := geminiRetryBaseDelay * time.Duration(1<<uint(attempt-1))
|
||||
if delay > geminiRetryMaxDelay {
|
||||
delay = geminiRetryMaxDelay
|
||||
}
|
||||
|
||||
// +/- 20% jitter
|
||||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||
jitter := time.Duration(float64(delay) * 0.2 * (r.Float64()*2 - 1))
|
||||
sleepFor := delay + jitter
|
||||
if sleepFor < 0 {
|
||||
sleepFor = 0
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(sleepFor):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
@@ -1523,3 +1584,36 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// extractImageSize 从 Gemini 请求中提取 image_size 参数
|
||||
func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
|
||||
var req antigravity.GeminiRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return "2K" // 默认 2K
|
||||
}
|
||||
|
||||
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
|
||||
size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
|
||||
if size == "1K" || size == "2K" || size == "4K" {
|
||||
return size
|
||||
}
|
||||
}
|
||||
|
||||
return "2K" // 默认 2K
|
||||
}
|
||||
|
||||
// isImageGenerationModel 判断模型是否为图片生成模型
|
||||
// 支持的模型:gemini-3-pro-image, gemini-3-pro-image-preview, gemini-2.5-flash-image 等
|
||||
func isImageGenerationModel(model string) bool {
|
||||
modelLower := strings.ToLower(model)
|
||||
// 移除 models/ 前缀
|
||||
modelLower = strings.TrimPrefix(modelLower, "models/")
|
||||
|
||||
// 精确匹配或前缀匹配
|
||||
return modelLower == "gemini-3-pro-image" ||
|
||||
modelLower == "gemini-3-pro-image-preview" ||
|
||||
strings.HasPrefix(modelLower, "gemini-3-pro-image-") ||
|
||||
modelLower == "gemini-2.5-flash-image" ||
|
||||
modelLower == "gemini-2.5-flash-image-preview" ||
|
||||
strings.HasPrefix(modelLower, "gemini-2.5-flash-image-")
|
||||
}
|
||||
|
||||
123
backend/internal/service/antigravity_image_test.go
Normal file
123
backend/internal/service/antigravity_image_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsImageGenerationModel_GeminiProImage 测试 gemini-3-pro-image 识别
|
||||
func TestIsImageGenerationModel_GeminiProImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image-preview"))
|
||||
require.True(t, isImageGenerationModel("models/gemini-3-pro-image"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_GeminiFlashImage 测试 gemini-2.5-flash-image 识别
|
||||
func TestIsImageGenerationModel_GeminiFlashImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image-preview"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_RegularModel 测试普通模型不被识别为图片模型
|
||||
func TestIsImageGenerationModel_RegularModel(t *testing.T) {
|
||||
require.False(t, isImageGenerationModel("claude-3-opus"))
|
||||
require.False(t, isImageGenerationModel("claude-sonnet-4-20250514"))
|
||||
require.False(t, isImageGenerationModel("gpt-4o"))
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
|
||||
// 验证不会误匹配包含关键词的自定义模型名
|
||||
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
|
||||
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
|
||||
func TestIsImageGenerationModel_CaseInsensitive(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("GEMINI-3-PRO-IMAGE"))
|
||||
require.True(t, isImageGenerationModel("Gemini-3-Pro-Image"))
|
||||
require.True(t, isImageGenerationModel("GEMINI-2.5-FLASH-IMAGE"))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_ValidSizes 测试有效尺寸解析
|
||||
func TestExtractImageSize_ValidSizes(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 1K
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
// 2K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 4K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
|
||||
func TestExtractImageSize_CaseInsensitive(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
|
||||
func TestExtractImageSize_Default(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 无 generationConfig
|
||||
body := []byte(`{"contents":[]}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 generationConfig 但无 imageConfig
|
||||
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 imageConfig 但无 imageSize
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
|
||||
func TestExtractImageSize_InvalidJSON(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`not valid json`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"broken":`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
|
||||
func TestExtractImageSize_EmptySize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 空格
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
|
||||
func TestExtractImageSize_InvalidSize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
@@ -295,3 +295,88 @@ func (s *BillingService) ForceUpdatePricing() error {
|
||||
}
|
||||
return fmt.Errorf("pricing service not initialized")
|
||||
}
|
||||
|
||||
// ImagePriceConfig 图片计费配置
|
||||
type ImagePriceConfig struct {
|
||||
Price1K *float64 // 1K 尺寸价格(nil 表示使用默认值)
|
||||
Price2K *float64 // 2K 尺寸价格(nil 表示使用默认值)
|
||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||
}
|
||||
|
||||
// CalculateImageCost 计算图片生成费用
|
||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||
// imageCount: 生成的图片数量
|
||||
// groupConfig: 分组配置的价格(可能为 nil,表示使用默认值)
|
||||
// rateMultiplier: 费率倍数
|
||||
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
// 获取单价
|
||||
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
|
||||
|
||||
// 计算总费用
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
|
||||
// 应用倍率
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// getImageUnitPrice 获取图片单价
|
||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||
// 优先使用分组配置的价格
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
if groupConfig.Price1K != nil {
|
||||
return *groupConfig.Price1K
|
||||
}
|
||||
case "2K":
|
||||
if groupConfig.Price2K != nil {
|
||||
return *groupConfig.Price2K
|
||||
}
|
||||
case "4K":
|
||||
if groupConfig.Price4K != nil {
|
||||
return *groupConfig.Price4K
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 LiteLLM 默认价格
|
||||
return s.getDefaultImagePrice(model, imageSize)
|
||||
}
|
||||
|
||||
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
|
||||
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
|
||||
basePrice := 0.0
|
||||
|
||||
// 从 PricingService 获取 output_cost_per_image
|
||||
if s.pricingService != nil {
|
||||
pricing := s.pricingService.GetModelPricing(model)
|
||||
if pricing != nil && pricing.OutputCostPerImage > 0 {
|
||||
basePrice = pricing.OutputCostPerImage
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview)
|
||||
if basePrice <= 0 {
|
||||
basePrice = 0.134
|
||||
}
|
||||
|
||||
// 4K 尺寸翻倍
|
||||
if imageSize == "4K" {
|
||||
return basePrice * 2
|
||||
}
|
||||
|
||||
return basePrice
|
||||
}
|
||||
|
||||
149
backend/internal/service/billing_service_image_test.go
Normal file
149
backend/internal/service/billing_service_image_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCalculateImageCost_DefaultPricing 测试无分组配置时使用默认价格
|
||||
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值
|
||||
|
||||
// 2K 尺寸,默认价格 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001)
|
||||
|
||||
// 多张图片
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
|
||||
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
|
||||
func TestCalculateImageCost_GroupCustomPricing(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
Price2K: &price2K,
|
||||
Price4K: &price4K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 2, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.15, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.30, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍
|
||||
func TestCalculateImageCost_4KDoublePrice(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 4K 尺寸,默认价格翻倍 $0.134 * 2 = $0.268
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_RateMultiplier 测试费率倍数
|
||||
func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001) // TotalCost 不变
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // ActualCost = 0.134 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.536, cost.ActualCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
|
||||
func TestCalculateImageCost_ZeroCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 0, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_NegativeCount 测试 imageCount=-1
|
||||
func TestCalculateImageCost_NegativeCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", -1, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
|
||||
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.134, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
func TestGetImageUnitPrice_GroupPriorityOverDefault(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price2K := 0.20
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price2K: &price2K,
|
||||
}
|
||||
|
||||
// 分组配置了 2K 价格,应该使用分组价格而不是默认的 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_PartialGroupConfig 测试分组部分配置时回退默认
|
||||
func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 只配置 1K 价格
|
||||
price1K := 0.10
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 回退默认价格 $0.134
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 回退默认价格 $0.268 (翻倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetDefaultImagePrice_FallbackHardcoded 测试 PricingService 无数据时使用硬编码默认值
|
||||
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil
|
||||
|
||||
// 1K 和 2K 使用相同的默认价格 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
}
|
||||
@@ -104,6 +104,10 @@ type ForwardResult struct {
|
||||
Stream bool
|
||||
Duration time.Duration
|
||||
FirstTokenMs *int // 首字时间(流式请求)
|
||||
|
||||
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||
}
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
@@ -2009,25 +2013,40 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 计算费用
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
|
||||
// 获取费率倍数
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
}
|
||||
|
||||
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
// 使用默认费用继续
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
var cost *CostBreakdown
|
||||
|
||||
// 根据请求类型选择计费方式
|
||||
if result.ImageCount > 0 {
|
||||
// 图片生成计费
|
||||
var groupConfig *ImagePriceConfig
|
||||
if apiKey.Group != nil {
|
||||
groupConfig = &ImagePriceConfig{
|
||||
Price1K: apiKey.Group.ImagePrice1K,
|
||||
Price2K: apiKey.Group.ImagePrice2K,
|
||||
Price4K: apiKey.Group.ImagePrice4K,
|
||||
}
|
||||
}
|
||||
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
|
||||
} else {
|
||||
// Token 计费
|
||||
tokens := UsageTokens{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
|
||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||
}
|
||||
var err error
|
||||
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
|
||||
if err != nil {
|
||||
log.Printf("Calculate cost failed: %v", err)
|
||||
cost = &CostBreakdown{ActualCost: 0}
|
||||
}
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
@@ -2039,6 +2058,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// 创建使用日志
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
var imageSize *string
|
||||
if result.ImageSize != "" {
|
||||
imageSize = &result.ImageSize
|
||||
}
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
@@ -2060,6 +2083,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: imageSize,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,11 @@ type Group struct {
|
||||
MonthlyLimitUSD *float64
|
||||
DefaultValidityDays int
|
||||
|
||||
// 图片生成计费配置(antigravity 和 gemini 平台使用)
|
||||
ImagePrice1K *float64
|
||||
ImagePrice2K *float64
|
||||
ImagePrice4K *float64
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -47,3 +52,19 @@ func (g *Group) HasWeeklyLimit() bool {
|
||||
func (g *Group) HasMonthlyLimit() bool {
|
||||
return g.MonthlyLimitUSD != nil && *g.MonthlyLimitUSD > 0
|
||||
}
|
||||
|
||||
// GetImagePrice 根据 image_size 返回对应的图片生成价格
|
||||
// 如果分组未配置价格,返回 nil(调用方应使用默认值)
|
||||
func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
return g.ImagePrice1K
|
||||
case "2K":
|
||||
return g.ImagePrice2K
|
||||
case "4K":
|
||||
return g.ImagePrice4K
|
||||
default:
|
||||
// 未知尺寸默认按 2K 计费
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
92
backend/internal/service/group_test.go
Normal file
92
backend/internal/service/group_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGroup_GetImagePrice_1K 测试 1K 尺寸返回正确价格
|
||||
func TestGroup_GetImagePrice_1K(t *testing.T) {
|
||||
price := 0.10
|
||||
group := &Group{
|
||||
ImagePrice1K: &price,
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("1K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.10, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_2K 测试 2K 尺寸返回正确价格
|
||||
func TestGroup_GetImagePrice_2K(t *testing.T) {
|
||||
price := 0.15
|
||||
group := &Group{
|
||||
ImagePrice2K: &price,
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("2K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.15, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_4K 测试 4K 尺寸返回正确价格
|
||||
func TestGroup_GetImagePrice_4K(t *testing.T) {
|
||||
price := 0.30
|
||||
group := &Group{
|
||||
ImagePrice4K: &price,
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("4K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.30, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_UnknownSize 测试未知尺寸回退 2K
|
||||
func TestGroup_GetImagePrice_UnknownSize(t *testing.T) {
|
||||
price2K := 0.15
|
||||
group := &Group{
|
||||
ImagePrice2K: &price2K,
|
||||
}
|
||||
|
||||
// 未知尺寸 "3K" 应该回退到 2K
|
||||
result := group.GetImagePrice("3K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.15, *result, 0.0001)
|
||||
|
||||
// 空字符串也回退到 2K
|
||||
result = group.GetImagePrice("")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.15, *result, 0.0001)
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_NilValues 测试未配置时返回 nil
|
||||
func TestGroup_GetImagePrice_NilValues(t *testing.T) {
|
||||
group := &Group{
|
||||
// 所有 ImagePrice 字段都是 nil
|
||||
}
|
||||
|
||||
require.Nil(t, group.GetImagePrice("1K"))
|
||||
require.Nil(t, group.GetImagePrice("2K"))
|
||||
require.Nil(t, group.GetImagePrice("4K"))
|
||||
require.Nil(t, group.GetImagePrice("unknown"))
|
||||
}
|
||||
|
||||
// TestGroup_GetImagePrice_PartialConfig 测试部分配置
|
||||
func TestGroup_GetImagePrice_PartialConfig(t *testing.T) {
|
||||
price1K := 0.10
|
||||
group := &Group{
|
||||
ImagePrice1K: &price1K,
|
||||
// ImagePrice2K 和 ImagePrice4K 未配置
|
||||
}
|
||||
|
||||
result := group.GetImagePrice("1K")
|
||||
require.NotNil(t, result)
|
||||
require.InDelta(t, 0.10, *result, 0.0001)
|
||||
|
||||
// 2K 和 4K 返回 nil
|
||||
require.Nil(t, group.GetImagePrice("2K"))
|
||||
require.Nil(t, group.GetImagePrice("4K"))
|
||||
}
|
||||
@@ -34,6 +34,7 @@ type LiteLLMModelPricing struct {
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage float64 `json:"output_cost_per_image"` // 图片生成模型每张图片价格
|
||||
}
|
||||
|
||||
// PricingRemoteClient 远程价格数据获取接口
|
||||
@@ -51,6 +52,7 @@ type LiteLLMRawEntry struct {
|
||||
LiteLLMProvider string `json:"litellm_provider"`
|
||||
Mode string `json:"mode"`
|
||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||
OutputCostPerImage *float64 `json:"output_cost_per_image"`
|
||||
}
|
||||
|
||||
// PricingService 动态价格服务
|
||||
@@ -319,6 +321,9 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel
|
||||
if entry.CacheReadInputTokenCost != nil {
|
||||
pricing.CacheReadInputTokenCost = *entry.CacheReadInputTokenCost
|
||||
}
|
||||
if entry.OutputCostPerImage != nil {
|
||||
pricing.OutputCostPerImage = *entry.OutputCostPerImage
|
||||
}
|
||||
|
||||
result[modelName] = pricing
|
||||
}
|
||||
|
||||
@@ -39,6 +39,10 @@ type UsageLog struct {
|
||||
DurationMs *int
|
||||
FirstTokenMs *int
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int
|
||||
ImageSize *string
|
||||
|
||||
CreatedAt time.Time
|
||||
|
||||
User *User
|
||||
|
||||
Reference in New Issue
Block a user