Files
xinghuoapi/backend/internal/repository/proxy_repo.go
yangjianbo ae191f72a4 fix(仓储): 修复软删除过滤与事务测试
修复软删除拦截器使用错误,确保默认查询过滤已删记录
仓储层改用 ent.Tx 与扫描辅助,避免 sql.Tx 断言问题
同步更新集成测试以覆盖事务与统计变动
2025-12-29 19:23:49 +08:00

264 lines
6.9 KiB
Go

package repository
import (
"context"
"database/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
// sqlQuerier is the minimal raw-SQL query capability the repository needs.
// It is satisfied by *sql.DB (see NewProxyRepository), and keeping it as an
// interface lets newProxyRepositoryWithSQL accept any other implementation.
type sqlQuerier interface {
// QueryContext runs a query with the given arguments and returns its rows.
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// proxyRepository implements service.ProxyRepository. Proxy rows are read and
// written through the ent client; account counts are computed with raw SQL
// through the sql querier.
type proxyRepository struct {
// client is the ent client used for all Proxy entity operations.
client *dbent.Client
// sql runs the raw queries against the accounts table.
sql sqlQuerier
}
// NewProxyRepository returns the service.ProxyRepository implementation that
// uses client for Proxy entity access and sqlDB for raw SQL account counts.
func NewProxyRepository(client *dbent.Client, sqlDB *sql.DB) service.ProxyRepository {
	repo := newProxyRepositoryWithSQL(client, sqlDB)
	return repo
}
// newProxyRepositoryWithSQL wires a proxyRepository from an ent client and an
// arbitrary sqlQuerier, so the raw SQL backend does not have to be a *sql.DB.
func newProxyRepositoryWithSQL(client *dbent.Client, sqlq sqlQuerier) *proxyRepository {
	repo := new(proxyRepository)
	repo.client = client
	repo.sql = sqlq
	return repo
}
// Create inserts a new proxy row from proxyIn. Username and Password are only
// set when non-empty. On success the generated ID and the created/updated
// timestamps are copied back into proxyIn.
func (r *proxyRepository) Create(ctx context.Context, proxyIn *service.Proxy) error {
	create := r.client.Proxy.Create().
		SetName(proxyIn.Name).
		SetProtocol(proxyIn.Protocol).
		SetHost(proxyIn.Host).
		SetPort(proxyIn.Port).
		SetStatus(proxyIn.Status)

	// Credentials are optional: leave the columns unset when empty.
	if username := proxyIn.Username; username != "" {
		create = create.SetUsername(username)
	}
	if password := proxyIn.Password; password != "" {
		create = create.SetPassword(password)
	}

	saved, err := create.Save(ctx)
	if err != nil {
		return err
	}
	applyProxyEntityToService(proxyIn, saved)
	return nil
}
// GetByID loads the proxy with the given id. An ent not-found error is
// translated to service.ErrProxyNotFound; any other error is returned as-is.
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
	entity, err := r.client.Proxy.Get(ctx, id)
	switch {
	case err == nil:
		return proxyEntityToService(entity), nil
	case dbent.IsNotFound(err):
		return nil, service.ErrProxyNotFound
	default:
		return nil, err
	}
}
// Update overwrites the proxy identified by proxyIn.ID with the fields of
// proxyIn. An empty Username or Password clears the corresponding column.
// On success the ID and timestamps from the saved row are copied back into
// proxyIn. An ent not-found error is translated to service.ErrProxyNotFound.
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
	update := r.client.Proxy.UpdateOneID(proxyIn.ID).
		SetName(proxyIn.Name).
		SetProtocol(proxyIn.Protocol).
		SetHost(proxyIn.Host).
		SetPort(proxyIn.Port).
		SetStatus(proxyIn.Status)

	// Empty credentials mean "no credentials": clear rather than store "".
	if proxyIn.Username == "" {
		update = update.ClearUsername()
	} else {
		update = update.SetUsername(proxyIn.Username)
	}
	if proxyIn.Password == "" {
		update = update.ClearPassword()
	} else {
		update = update.SetPassword(proxyIn.Password)
	}

	saved, err := update.Save(ctx)
	switch {
	case err == nil:
		applyProxyEntityToService(proxyIn, saved)
		return nil
	case dbent.IsNotFound(err):
		return service.ErrProxyNotFound
	}
	return err
}
// Delete removes the proxy whose ID equals id. The affected-row count is
// discarded, so this method does not itself distinguish a missing ID from a
// successful delete; only the error from the delete execution is returned.
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
	if _, err := r.client.Proxy.Delete().Where(proxy.IDEQ(id)).Exec(ctx); err != nil {
		return err
	}
	return nil
}
// List returns one page of proxies without any filtering. It delegates to
// ListWithFilters with every filter left empty.
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
	const noFilter = ""
	return r.ListWithFilters(ctx, params, noFilter, noFilter, noFilter)
}
// ListWithFilters returns one page of proxies, newest ID first, together with
// pagination metadata built from the total number of matching rows.
//
// Each filter is applied only when non-empty: protocol and status are exact
// matches, and search is a case-insensitive substring match on the name.
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
	query := r.client.Proxy.Query()
	if protocol != "" {
		query = query.Where(proxy.ProtocolEQ(protocol))
	}
	if status != "" {
		query = query.Where(proxy.StatusEQ(status))
	}
	if search != "" {
		query = query.Where(proxy.NameContainsFold(search))
	}

	// Count the full filtered set before the page window is applied.
	total, err := query.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	page := query.
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Desc(proxy.FieldID))
	rows, err := page.All(ctx)
	if err != nil {
		return nil, nil, err
	}

	converted := make([]service.Proxy, 0, len(rows))
	for _, row := range rows {
		converted = append(converted, *proxyEntityToService(row))
	}
	return converted, paginationResultFromTotal(int64(total), params), nil
}
// ListActive returns every proxy whose status equals service.StatusActive.
// No explicit ordering is requested from the database.
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
	activeOnly := proxy.StatusEQ(service.StatusActive)
	rows, err := r.client.Proxy.Query().Where(activeOnly).All(ctx)
	if err != nil {
		return nil, err
	}

	converted := make([]service.Proxy, 0, len(rows))
	for _, row := range rows {
		converted = append(converted, *proxyEntityToService(row))
	}
	return converted, nil
}
// ExistsByHostPortAuth reports whether a proxy with the same host, port,
// username, and password already exists.
//
// An empty username or password matches rows where the column is either NULL
// or the empty string, because Create leaves empty credentials unset and
// Update clears them.
//
// It returns (false, err) when the existence query fails.
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
	q := r.client.Proxy.Query().
		Where(proxy.HostEQ(host), proxy.PortEQ(port))

	if username == "" {
		q = q.Where(proxy.Or(proxy.UsernameIsNil(), proxy.UsernameEQ("")))
	} else {
		q = q.Where(proxy.UsernameEQ(username))
	}

	if password == "" {
		q = q.Where(proxy.Or(proxy.PasswordIsNil(), proxy.PasswordEQ("")))
	} else {
		q = q.Where(proxy.PasswordEQ(password))
	}

	// Only existence matters, so let the database stop at the first match
	// instead of counting every matching row.
	return q.Exist(ctx)
}
// CountAccountsByProxyID returns how many rows in the accounts table reference
// the given proxy ID. On failure it returns 0 and the error.
//
// NOTE(review): the raw query has no soft-delete condition; if accounts are
// soft-deleted, confirm whether deleted rows should be counted here.
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
	var total int64
	err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &total)
	if err != nil {
		return 0, err
	}
	return total, nil
}
// GetAccountCountsForProxies returns a map from proxy ID to the number of
// account rows that reference it. Accounts with a NULL proxy_id are ignored,
// and proxies with no accounts have no entry in the map.
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (map[int64]int64, error) {
	const query = "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL GROUP BY proxy_id"

	rows, err := r.sql.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	perProxy := map[int64]int64{}
	for rows.Next() {
		var (
			id int64
			n  int64
		)
		if scanErr := rows.Scan(&id, &n); scanErr != nil {
			return nil, scanErr
		}
		perProxy[id] = n
	}
	// Next returning false may mean an iteration error rather than the end
	// of the result set, so check before trusting the map.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return perProxy, nil
}
// ListActiveWithAccountCount returns all proxies whose status equals
// service.StatusActive, sorted by creation time descending, each paired with
// the number of accounts that reference it (zero when none do).
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
	rows, err := r.client.Proxy.Query().
		Where(proxy.StatusEQ(service.StatusActive)).
		Order(dbent.Desc(proxy.FieldCreatedAt)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// One grouped query covers every proxy; a missing key reads as zero.
	perProxy, err := r.GetAccountCountsForProxies(ctx)
	if err != nil {
		return nil, err
	}

	combined := make([]service.ProxyWithAccountCount, 0, len(rows))
	for _, row := range rows {
		converted := proxyEntityToService(row)
		if converted != nil {
			combined = append(combined, service.ProxyWithAccountCount{
				Proxy:        *converted,
				AccountCount: perProxy[converted.ID],
			})
		}
	}
	return combined, nil
}
// proxyEntityToService converts an ent Proxy entity into a service.Proxy.
// It returns nil for a nil entity. Nil Username/Password pointers become
// empty strings in the result.
func proxyEntityToService(m *dbent.Proxy) *service.Proxy {
	if m == nil {
		return nil
	}

	var converted service.Proxy
	converted.ID = m.ID
	converted.Name = m.Name
	converted.Protocol = m.Protocol
	converted.Host = m.Host
	converted.Port = m.Port
	converted.Status = m.Status
	converted.CreatedAt = m.CreatedAt
	converted.UpdatedAt = m.UpdatedAt

	if username := m.Username; username != nil {
		converted.Username = *username
	}
	if password := m.Password; password != nil {
		converted.Password = *password
	}
	return &converted
}
// applyProxyEntityToService copies the ID, CreatedAt, and UpdatedAt fields
// from src onto dst. It does nothing when either argument is nil.
func applyProxyEntityToService(dst *service.Proxy, src *dbent.Proxy) {
	if dst != nil && src != nil {
		dst.ID, dst.CreatedAt, dst.UpdatedAt = src.ID, src.CreatedAt, src.UpdatedAt
	}
}