feat: add doubao video use quota by total token
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
"one-api/relay"
|
"one-api/relay"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/setting/ratio_setting"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -120,6 +121,89 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
|||||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||||
task.FailReason = taskResult.Url
|
task.FailReason = taskResult.Url
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费
|
||||||
|
if taskResult.TotalTokens > 0 {
|
||||||
|
// 获取模型名称
|
||||||
|
var taskData map[string]interface{}
|
||||||
|
if err := json.Unmarshal(task.Data, &taskData); err == nil {
|
||||||
|
if modelName, ok := taskData["model"].(string); ok && modelName != "" {
|
||||||
|
// 获取模型价格和倍率
|
||||||
|
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||||||
|
|
||||||
|
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||||||
|
if hasRatioSetting && modelRatio > 0 {
|
||||||
|
// 获取用户和组的倍率信息
|
||||||
|
user, err := model.GetUserById(task.UserId, false)
|
||||||
|
if err == nil {
|
||||||
|
groupRatio := ratio_setting.GetGroupRatio(user.Group)
|
||||||
|
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group)
|
||||||
|
|
||||||
|
var finalGroupRatio float64
|
||||||
|
if hasUserGroupRatio {
|
||||||
|
finalGroupRatio = userGroupRatio
|
||||||
|
} else {
|
||||||
|
finalGroupRatio = groupRatio
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||||
|
actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio)
|
||||||
|
|
||||||
|
// 计算差额
|
||||||
|
preConsumedQuota := task.Quota
|
||||||
|
quotaDelta := actualQuota - preConsumedQuota
|
||||||
|
|
||||||
|
if quotaDelta > 0 {
|
||||||
|
// 需要补扣费
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||||
|
task.TaskID,
|
||||||
|
logger.LogQuota(quotaDelta),
|
||||||
|
logger.LogQuota(actualQuota),
|
||||||
|
logger.LogQuota(preConsumedQuota),
|
||||||
|
taskResult.TotalTokens,
|
||||||
|
))
|
||||||
|
if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error()))
|
||||||
|
} else {
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||||
|
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||||
|
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||||
|
|
||||||
|
// 记录消费日志
|
||||||
|
logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d",
|
||||||
|
modelRatio, finalGroupRatio, taskResult.TotalTokens)
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
} else if quotaDelta < 0 {
|
||||||
|
// 需要退还多扣的费用
|
||||||
|
refundQuota := -quotaDelta
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)",
|
||||||
|
task.TaskID,
|
||||||
|
logger.LogQuota(refundQuota),
|
||||||
|
logger.LogQuota(actualQuota),
|
||||||
|
logger.LogQuota(preConsumedQuota),
|
||||||
|
taskResult.TotalTokens,
|
||||||
|
))
|
||||||
|
if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error()))
|
||||||
|
} else {
|
||||||
|
task.Quota = actualQuota // 更新任务记录的实际扣费额度
|
||||||
|
|
||||||
|
// 记录退款日志
|
||||||
|
logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,退还 %s",
|
||||||
|
modelRatio, finalGroupRatio, taskResult.TotalTokens, logger.LogQuota(refundQuota))
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// quotaDelta == 0, 预扣费刚好准确
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
||||||
|
task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
case model.TaskStatusFailure:
|
case model.TaskStatusFailure:
|
||||||
task.Status = model.TaskStatusFailure
|
task.Status = model.TaskStatusFailure
|
||||||
task.Progress = "100%"
|
task.Progress = "100%"
|
||||||
|
|||||||
@@ -231,6 +231,9 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
|||||||
taskResult.Status = model.TaskStatusSuccess
|
taskResult.Status = model.TaskStatusSuccess
|
||||||
taskResult.Progress = "100%"
|
taskResult.Progress = "100%"
|
||||||
taskResult.Url = resTask.Content.VideoURL
|
taskResult.Url = resTask.Content.VideoURL
|
||||||
|
// 解析 usage 信息用于按倍率计费
|
||||||
|
taskResult.CompletionTokens = resTask.Usage.CompletionTokens
|
||||||
|
taskResult.TotalTokens = resTask.Usage.TotalTokens
|
||||||
case "failed":
|
case "failed":
|
||||||
taskResult.Status = model.TaskStatusFailure
|
taskResult.Status = model.TaskStatusFailure
|
||||||
taskResult.Progress = "100%"
|
taskResult.Progress = "100%"
|
||||||
|
|||||||
@@ -500,10 +500,12 @@ func (t TaskSubmitReq) HasImage() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TaskInfo struct {
|
type TaskInfo struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
TaskID string `json:"task_id"`
|
TaskID string `json:"task_id"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Reason string `json:"reason,omitempty"`
|
Reason string `json:"reason,omitempty"`
|
||||||
Url string `json:"url,omitempty"`
|
Url string `json:"url,omitempty"`
|
||||||
Progress string `json:"progress,omitempty"`
|
Progress string `json:"progress,omitempty"`
|
||||||
|
CompletionTokens int `json:"completion_tokens,omitempty"` // 用于按倍率计费
|
||||||
|
TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user