feat(sync): full code sync from release
This commit is contained in:
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
|
||||
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
|
||||
}
|
||||
|
||||
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
|
||||
// 返回结构:map[groupID]exists。
|
||||
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
|
||||
result := make(map[int64]bool, len(ids))
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
uniqueIDs := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
result[id] = false
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id
|
||||
FROM groups
|
||||
WHERE id = ANY($1) AND deleted_at IS NULL
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[id] = true
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用事务批量更新
|
||||
tx, err := r.client.Tx(ctx)
|
||||
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
|
||||
sortOrderByID := make(map[int64]int, len(updates))
|
||||
groupIDs := make([]int64, 0, len(updates))
|
||||
for _, u := range updates {
|
||||
if u.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := sortOrderByID[u.ID]; !exists {
|
||||
groupIDs = append(groupIDs, u.ID)
|
||||
}
|
||||
sortOrderByID[u.ID] = u.SortOrder
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。
|
||||
var existingCount int
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
|
||||
[]any{pq.Array(groupIDs)},
|
||||
&existingCount,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if existingCount != len(groupIDs) {
|
||||
return service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
args := make([]any, 0, len(groupIDs)*2+1)
|
||||
caseClauses := make([]string, 0, len(groupIDs))
|
||||
placeholder := 1
|
||||
for _, id := range groupIDs {
|
||||
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
|
||||
args = append(args, id, sortOrderByID[id])
|
||||
placeholder += 2
|
||||
}
|
||||
args = append(args, pq.Array(groupIDs))
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE groups
|
||||
SET sort_order = CASE id
|
||||
%s
|
||||
ELSE sort_order
|
||||
END
|
||||
WHERE deleted_at IS NULL AND id = ANY($%d)
|
||||
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
|
||||
|
||||
result, err := r.sql.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
for _, u := range updates {
|
||||
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != int64(len(groupIDs)) {
|
||||
return service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
for _, id := range groupIDs {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user