refactor(task): extract billing and polling logic from controller to service layer
Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format
This commit is contained in:
227
service/task_billing.go
Normal file
227
service/task_billing.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
|
||||
// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
|
||||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo, modelName string) {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
// 支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
other["request_path"] = c.Request.URL.Path
|
||||
other["model_price"] = info.PriceData.ModelPrice
|
||||
other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
|
||||
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: info.ChannelId,
|
||||
ModelName: modelName,
|
||||
TokenName: tokenName,
|
||||
Quota: info.PriceData.Quota,
|
||||
Content: logContent,
|
||||
TokenId: info.TokenId,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
|
||||
model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 异步任务计费辅助函数
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
|
||||
// 如果令牌已被删除或查询失败,返回空字符串。
|
||||
func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
|
||||
token, err := model.GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
|
||||
return ""
|
||||
}
|
||||
return token.Key
|
||||
}
|
||||
|
||||
// taskIsSubscription 判断任务是否通过订阅计费。
|
||||
func taskIsSubscription(task *model.Task) bool {
|
||||
return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
|
||||
}
|
||||
|
||||
// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
|
||||
func taskAdjustFunding(task *model.Task, delta int) error {
|
||||
if taskIsSubscription(task) {
|
||||
return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
|
||||
}
|
||||
if delta > 0 {
|
||||
return model.DecreaseUserQuota(task.UserId, delta)
|
||||
}
|
||||
return model.IncreaseUserQuota(task.UserId, -delta, false)
|
||||
}
|
||||
|
||||
// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
|
||||
// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
|
||||
func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
|
||||
if task.PrivateData.TokenId <= 0 || delta == 0 {
|
||||
return
|
||||
}
|
||||
tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
|
||||
if tokenKey == "" {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if delta > 0 {
|
||||
err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
|
||||
} else {
|
||||
err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
|
||||
}
|
||||
if err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// RefundTaskQuota 统一的任务失败退款逻辑。
|
||||
// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
|
||||
func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
|
||||
quota := task.Quota
|
||||
if quota == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 退还资金来源(钱包或订阅)
|
||||
if err := taskAdjustFunding(task, -quota); err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 退还令牌额度
|
||||
taskAdjustTokenQuota(ctx, task, -quota)
|
||||
|
||||
// 3. 记录日志
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s,原因:%s", task.TaskID, logger.LogQuota(quota), reason)
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
|
||||
// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
|
||||
// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
|
||||
func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
|
||||
if totalTokens <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取模型名称
|
||||
var taskData map[string]interface{}
|
||||
if err := common.Unmarshal(task.Data, &taskData); err != nil {
|
||||
return
|
||||
}
|
||||
modelName, ok := taskData["model"].(string)
|
||||
if !ok || modelName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取模型价格和倍率
|
||||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||
if !hasRatioSetting || modelRatio <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户和组的倍率信息
|
||||
group := task.Group
|
||||
if group == "" {
|
||||
user, err := model.GetUserById(task.UserId, false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
if group == "" {
|
||||
return
|
||||
}
|
||||
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||||
|
||||
var finalGroupRatio float64
|
||||
if hasUserGroupRatio {
|
||||
finalGroupRatio = userGroupRatio
|
||||
} else {
|
||||
finalGroupRatio = groupRatio
|
||||
}
|
||||
|
||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
|
||||
|
||||
// 计算差额(正数=需要补扣,负数=需要退还)
|
||||
preConsumedQuota := task.Quota
|
||||
quotaDelta := actualQuota - preConsumedQuota
|
||||
|
||||
if quotaDelta == 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
||||
task.TaskID, logger.LogQuota(actualQuota), totalTokens))
|
||||
return
|
||||
}
|
||||
|
||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)",
|
||||
task.TaskID,
|
||||
logger.LogQuota(quotaDelta),
|
||||
logger.LogQuota(actualQuota),
|
||||
logger.LogQuota(preConsumedQuota),
|
||||
totalTokens,
|
||||
))
|
||||
|
||||
// 调整资金来源
|
||||
if err := taskAdjustFunding(task, quotaDelta); err != nil {
|
||||
logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 调整令牌额度
|
||||
taskAdjustTokenQuota(ctx, task, quotaDelta)
|
||||
|
||||
// 更新统计(仅补扣时更新,退还不影响已用统计)
|
||||
if quotaDelta > 0 {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||
}
|
||||
task.Quota = actualQuota
|
||||
|
||||
var action string
|
||||
if quotaDelta > 0 {
|
||||
action = "补扣费"
|
||||
} else {
|
||||
action = "退还"
|
||||
}
|
||||
logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s",
|
||||
action, modelRatio, finalGroupRatio, totalTokens,
|
||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
Reference in New Issue
Block a user