feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks
1. Async task model redirection (aligned with sync tasks):
- Integrate ModelMappedHelper in RelayTaskSubmit after model name
determination, populating OriginModelName / UpstreamModelName on RelayInfo.
- All task adaptors now send UpstreamModelName to upstream providers:
  - Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
  - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
  - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
    and unconditionally uses info.UpstreamModelName.
  - Sora: BuildRequestBody parses JSON and multipart bodies to replace
    the "model" field with UpstreamModelName.
- Frontend log visibility: LogTaskConsumption and taskBillingOther now
emit is_model_mapped / upstream_model_name in the "other" JSON field.
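- Billing safety: RecalculateTaskQuotaByTokens now takes the model name from
BillingContext.OriginModelName (via taskModelName, which falls back to
Properties.OriginModelName) instead of task.Data["model"], so the ratio
lookup is always keyed by the original model name and never by an upstream
(mapped) model name.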
2. Per-call billing (TaskPricePatches lifecycle):
- Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
bool field, populated from TaskPricePatches at submission time.
- settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
skipping both adaptor adjustments and token-based recalculation.
- Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
consistently in controller/relay.go for billing context and logging.
3. Multipart retry boundary mismatch fix:
- Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
new boundary and overwrites c.Request.Header["Content-Type"], subsequent
calls to ParseMultipartFormReusable on retry would parse the cached
original body with the wrong boundary, causing "NextPart: EOF".
- Fix: ParseMultipartFormReusable now caches the original Content-Type in
gin context key "_original_multipart_ct" on first call and reuses it for
all subsequent parses, making multipart parsing retry-safe globally.
- Sora adaptor reverted to the standard pattern (direct header set/get),
which is now safe thanks to the root fix.
4. Tests:
- task_billing_test.go: update makeTask to use OriginModelName; add
PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
add non-per-call adaptor adjustment test with refund verification.
This commit is contained in:
@@ -16,11 +16,11 @@ import (
|
||||
|
||||
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
|
||||
// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
|
||||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
|
||||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
// 支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
@@ -42,9 +42,13 @@ func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName s
|
||||
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
if info.IsModelMapped {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = info.UpstreamModelName
|
||||
}
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: tokenName,
|
||||
Quota: info.PriceData.Quota,
|
||||
Content: logContent,
|
||||
@@ -120,13 +124,18 @@ func taskBillingOther(task *model.Task) map[string]interface{} {
|
||||
}
|
||||
}
|
||||
}
|
||||
props := task.Properties
|
||||
if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
|
||||
other["is_model_mapped"] = true
|
||||
other["upstream_model_name"] = props.UpstreamModelName
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
|
||||
func taskModelName(task *model.Task) string {
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.ModelName != "" {
|
||||
return bc.ModelName
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
|
||||
return bc.OriginModelName
|
||||
}
|
||||
return task.Properties.OriginModelName
|
||||
}
|
||||
@@ -237,15 +246,7 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
|
||||
return
|
||||
}
|
||||
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := common.Unmarshal(task.Data, &taskData); err != nil {
|
||||
return
|
||||
}
|
||||
modelName, ok := taskData["model"].(string)
|
||||
if !ok || modelName == "" {
|
||||
return
|
||||
}
|
||||
modelName := taskModelName(task)
|
||||
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
@@ -3,12 +3,14 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -125,7 +127,7 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc
|
||||
BillingContext: &model.TaskBillingContext{
|
||||
ModelPrice: 0.02,
|
||||
GroupRatio: 1.0,
|
||||
ModelName: "test-model",
|
||||
OriginModelName: "test-model",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -604,3 +606,107 @@ func TestNonTerminalUpdate_NoBilling(t *testing.T) {
|
||||
require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
|
||||
assert.Equal(t, "50%", reloaded.Progress)
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// Mock adaptor for settleTaskBillingOnComplete tests
|
||||
// ===========================================================================
|
||||
|
||||
type mockAdaptor struct {
|
||||
adjustReturn int
|
||||
}
|
||||
|
||||
func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {}
|
||||
func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil }
|
||||
func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil }
|
||||
func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||
return m.adjustReturn
|
||||
}
|
||||
|
||||
// ===========================================================================
|
||||
// PerCallBilling tests — settleTaskBillingOnComplete
|
||||
// ===========================================================================
|
||||
|
||||
func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 30, 30, 30
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.PrivateData.BillingContext.PerCallBilling = true
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: 2000}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Per-call: no adjustment despite adaptor returning 2000
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, preConsumed, task.Quota)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 31, 31, 31
|
||||
const initQuota, preConsumed = 10000, 4000
|
||||
const tokenRemain = 7000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
task.PrivateData.BillingContext.PerCallBilling = true
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: 0}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Per-call: no recalculation by tokens
|
||||
assert.Equal(t, initQuota, getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, preConsumed, task.Quota)
|
||||
assert.Equal(t, int64(0), countLogs(t))
|
||||
}
|
||||
|
||||
func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) {
|
||||
truncate(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const userID, tokenID, channelID = 32, 32, 32
|
||||
const initQuota, preConsumed = 10000, 5000
|
||||
const adaptorQuota = 3000
|
||||
const tokenRemain = 8000
|
||||
|
||||
seedUser(t, userID, initQuota)
|
||||
seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain)
|
||||
seedChannel(t, channelID)
|
||||
|
||||
task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
|
||||
// PerCallBilling defaults to false
|
||||
|
||||
adaptor := &mockAdaptor{adjustReturn: adaptorQuota}
|
||||
taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess}
|
||||
|
||||
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||
|
||||
// Non-per-call: adaptor adjustment applies (refund 2000)
|
||||
assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID))
|
||||
assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID))
|
||||
assert.Equal(t, adaptorQuota, task.Quota)
|
||||
|
||||
log := getLastLog(t)
|
||||
require.NotNil(t, log)
|
||||
assert.Equal(t, model.LogTypeRefund, log.Type)
|
||||
}
|
||||
|
||||
@@ -467,6 +467,11 @@ func truncateBase64(s string) string {
|
||||
// 2. taskResult.TotalTokens > 0 → 按 token 重算
|
||||
// 3. 都不满足 → 保持预扣额度不变
|
||||
func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
|
||||
// 0. 按次计费的任务不做差额结算
|
||||
if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID))
|
||||
return
|
||||
}
|
||||
// 1. 优先让 adaptor 决定最终额度
|
||||
if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
|
||||
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
|
||||
|
||||
Reference in New Issue
Block a user