@@ -0,0 +1,304 @@
package model
import (
"database/sql/driver"
"encoding/json"
"one-api/constant"
commonRelay "one-api/relay/common"
"time"
)
type TaskStatus string
const (
TaskStatusNotStart TaskStatus = "NOT_START"
TaskStatusSubmitted = "SUBMITTED"
TaskStatusQueued = "QUEUED"
TaskStatusInProgress = "IN_PROGRESS"
TaskStatusFailure = "FAILURE"
TaskStatusSuccess = "SUCCESS"
TaskStatusUnknown = "UNKNOWN"
)
type Task struct {
ID int64 ` json:"id" gorm:"primary_key;AUTO_INCREMENT" `
CreatedAt int64 ` json:"created_at" gorm:"index" `
UpdatedAt int64 ` json:"updated_at" `
TaskID string ` json:"task_id" gorm:"type:varchar(50);index" ` // 第三方id, 不一定有/ song id\ Task id
Platform constant . TaskPlatform ` json:"platform" gorm:"type:varchar(30);index" ` // 平台
UserId int ` json:"user_id" gorm:"index" `
ChannelId int ` json:"channel_id" gorm:"index" `
Quota int ` json:"quota" `
Action string ` json:"action" gorm:"type:varchar(40);index" ` // 任务类型, song, lyrics, description-mode
Status TaskStatus ` json:"status" gorm:"type:varchar(20);index" ` // 任务状态
FailReason string ` json:"fail_reason" `
SubmitTime int64 ` json:"submit_time" gorm:"index" `
StartTime int64 ` json:"start_time" gorm:"index" `
FinishTime int64 ` json:"finish_time" gorm:"index" `
Progress string ` json:"progress" gorm:"type:varchar(20);index" `
Properties Properties ` json:"properties" gorm:"type:json" `
Data json . RawMessage ` json:"data" gorm:"type:json" `
}
func ( t * Task ) SetData ( data any ) {
b , _ := json . Marshal ( data )
t . Data = json . RawMessage ( b )
}
func ( t * Task ) GetData ( v any ) error {
err := json . Unmarshal ( t . Data , & v )
return err
}
type Properties struct {
Input string ` json:"input" `
}
func ( m * Properties ) Scan ( val interface { } ) error {
bytesValue , _ := val . ( [ ] byte )
return json . Unmarshal ( bytesValue , m )
}
func ( m Properties ) Value ( ) ( driver . Value , error ) {
return json . Marshal ( m )
}
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
type SyncTaskQueryParams struct {
Platform constant . TaskPlatform
ChannelID string
TaskID string
UserID string
Action string
Status string
StartTimestamp int64
EndTimestamp int64
UserIDs [ ] int
}
func InitTask ( platform constant . TaskPlatform , relayInfo * commonRelay . TaskRelayInfo ) * Task {
t := & Task {
UserId : relayInfo . UserId ,
SubmitTime : time . Now ( ) . Unix ( ) ,
Status : TaskStatusNotStart ,
Progress : "0%" ,
ChannelId : relayInfo . ChannelId ,
Platform : platform ,
}
return t
}
func TaskGetAllUserTask ( userId int , startIdx int , num int , queryParams SyncTaskQueryParams ) [ ] * Task {
var tasks [ ] * Task
var err error
// 初始化查询构建器
query := DB . Where ( "user_id = ?" , userId )
if queryParams . TaskID != "" {
query = query . Where ( "task_id = ?" , queryParams . TaskID )
}
if queryParams . Action != "" {
query = query . Where ( "action = ?" , queryParams . Action )
}
if queryParams . Status != "" {
query = query . Where ( "status = ?" , queryParams . Status )
}
if queryParams . Platform != "" {
query = query . Where ( "platform = ?" , queryParams . Platform )
}
if queryParams . StartTimestamp != 0 {
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
query = query . Where ( "submit_time >= ?" , queryParams . StartTimestamp )
}
if queryParams . EndTimestamp != 0 {
query = query . Where ( "submit_time <= ?" , queryParams . EndTimestamp )
}
// 获取数据
err = query . Omit ( "channel_id" ) . Order ( "id desc" ) . Limit ( num ) . Offset ( startIdx ) . Find ( & tasks ) . Error
if err != nil {
return nil
}
return tasks
}
func TaskGetAllTasks ( startIdx int , num int , queryParams SyncTaskQueryParams ) [ ] * Task {
var tasks [ ] * Task
var err error
// 初始化查询构建器
query := DB
// 添加过滤条件
if queryParams . ChannelID != "" {
query = query . Where ( "channel_id = ?" , queryParams . ChannelID )
}
if queryParams . Platform != "" {
query = query . Where ( "platform = ?" , queryParams . Platform )
}
if queryParams . UserID != "" {
query = query . Where ( "user_id = ?" , queryParams . UserID )
}
if len ( queryParams . UserIDs ) != 0 {
query = query . Where ( "user_id in (?)" , queryParams . UserIDs )
}
if queryParams . TaskID != "" {
query = query . Where ( "task_id = ?" , queryParams . TaskID )
}
if queryParams . Action != "" {
query = query . Where ( "action = ?" , queryParams . Action )
}
if queryParams . Status != "" {
query = query . Where ( "status = ?" , queryParams . Status )
}
if queryParams . StartTimestamp != 0 {
query = query . Where ( "submit_time >= ?" , queryParams . StartTimestamp )
}
if queryParams . EndTimestamp != 0 {
query = query . Where ( "submit_time <= ?" , queryParams . EndTimestamp )
}
// 获取数据
err = query . Order ( "id desc" ) . Limit ( num ) . Offset ( startIdx ) . Find ( & tasks ) . Error
if err != nil {
return nil
}
return tasks
}
func GetAllUnFinishSyncTasks ( limit int ) [ ] * Task {
var tasks [ ] * Task
var err error
// get all tasks progress is not 100%
err = DB . Where ( "progress != ?" , "100%" ) . Limit ( limit ) . Order ( "id" ) . Find ( & tasks ) . Error
if err != nil {
return nil
}
return tasks
}
func GetByOnlyTaskId ( taskId string ) ( * Task , bool , error ) {
if taskId == "" {
return nil , false , nil
}
var task * Task
var err error
err = DB . Where ( "task_id = ?" , taskId ) . First ( & task ) . Error
exist , err := RecordExist ( err )
if err != nil {
return nil , false , err
}
return task , exist , err
}
func GetByTaskId ( userId int , taskId string ) ( * Task , bool , error ) {
if taskId == "" {
return nil , false , nil
}
var task * Task
var err error
err = DB . Where ( "user_id = ? and task_id = ?" , userId , taskId ) .
First ( & task ) . Error
exist , err := RecordExist ( err )
if err != nil {
return nil , false , err
}
return task , exist , err
}
func GetByTaskIds ( userId int , taskIds [ ] any ) ( [ ] * Task , error ) {
if len ( taskIds ) == 0 {
return nil , nil
}
var task [ ] * Task
var err error
err = DB . Where ( "user_id = ? and task_id in (?)" , userId , taskIds ) .
Find ( & task ) . Error
if err != nil {
return nil , err
}
return task , nil
}
func TaskUpdateProgress ( id int64 , progress string ) error {
return DB . Model ( & Task { } ) . Where ( "id = ?" , id ) . Update ( "progress" , progress ) . Error
}
func ( Task * Task ) Insert ( ) error {
var err error
err = DB . Create ( Task ) . Error
return err
}
func ( Task * Task ) Update ( ) error {
var err error
err = DB . Save ( Task ) . Error
return err
}
func TaskBulkUpdate ( TaskIds [ ] string , params map [ string ] any ) error {
if len ( TaskIds ) == 0 {
return nil
}
return DB . Model ( & Task { } ) .
Where ( "task_id in (?)" , TaskIds ) .
Updates ( params ) . Error
}
func TaskBulkUpdateByTaskIds ( taskIDs [ ] int64 , params map [ string ] any ) error {
if len ( taskIDs ) == 0 {
return nil
}
return DB . Model ( & Task { } ) .
Where ( "id in (?)" , taskIDs ) .
Updates ( params ) . Error
}
func TaskBulkUpdateByID ( ids [ ] int64 , params map [ string ] any ) error {
if len ( ids ) == 0 {
return nil
}
return DB . Model ( & Task { } ) .
Where ( "id in (?)" , ids ) .
Updates ( params ) . Error
}
type TaskQuotaUsage struct {
Mode string ` json:"mode" `
Count float64 ` json:"count" `
}
func SumUsedTaskQuota ( queryParams SyncTaskQueryParams ) ( stat [ ] TaskQuotaUsage , err error ) {
query := DB . Model ( Task { } )
// 添加过滤条件
if queryParams . ChannelID != "" {
query = query . Where ( "channel_id = ?" , queryParams . ChannelID )
}
if queryParams . UserID != "" {
query = query . Where ( "user_id = ?" , queryParams . UserID )
}
if len ( queryParams . UserIDs ) != 0 {
query = query . Where ( "user_id in (?)" , queryParams . UserIDs )
}
if queryParams . TaskID != "" {
query = query . Where ( "task_id = ?" , queryParams . TaskID )
}
if queryParams . Action != "" {
query = query . Where ( "action = ?" , queryParams . Action )
}
if queryParams . Status != "" {
query = query . Where ( "status = ?" , queryParams . Status )
}
if queryParams . StartTimestamp != 0 {
query = query . Where ( "submit_time >= ?" , queryParams . StartTimestamp )
}
if queryParams . EndTimestamp != 0 {
query = query . Where ( "submit_time <= ?" , queryParams . EndTimestamp )
}
err = query . Select ( "mode, sum(quota) as count" ) . Group ( "mode" ) . Find ( & stat ) . Error
return stat , err
}