Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
@@ -12,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
authCacheInvalidateChannel = "auth:cache:invalidate"
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
|
||||
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
|
||||
}
|
||||
|
||||
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
|
||||
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
|
||||
|
||||
// Verify subscription is working
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := pubsub.Close(); err != nil {
|
||||
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if msg != nil {
|
||||
handler(msg.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// SIMPLE 模式:启动时补齐各平台默认分组。
|
||||
// - anthropic/openai/gemini: 确保存在 <platform>-default
|
||||
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer seedCancel()
|
||||
if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,10 +15,19 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// debugLog prints log only in non-release mode.
|
||||
func debugLog(format string, v ...any) {
|
||||
if gin.Mode() != gin.ReleaseMode {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
// 默认配置常量
|
||||
// 这些值在配置文件未指定时作为回退默认值使用
|
||||
const (
|
||||
@@ -150,6 +160,170 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象
|
||||
// - proxyURL: 代理地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
// 如果未启用 TLS 指纹,直接使用标准请求路径
|
||||
if !enableTLSFingerprint {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// TLS 指纹已启用,记录调试日志
|
||||
targetHost := ""
|
||||
if req != nil && req.URL != nil {
|
||||
targetHost = req.URL.Host
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
}
|
||||
debugLog("[TLS Fingerprint] Account %d: TLS fingerprint ENABLED, target=%s, proxy=%s", accountID, targetHost, proxyInfo)
|
||||
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取 TLS 指纹 Profile
|
||||
registry := tlsfingerprint.GlobalRegistry()
|
||||
profile := registry.GetProfileByAccountID(accountID)
|
||||
if profile == nil {
|
||||
// 如果获取不到 profile,回退到普通请求
|
||||
debugLog("[TLS Fingerprint] Account %d: WARNING - no profile found, falling back to standard request", accountID)
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
debugLog("[TLS Fingerprint] Account %d: Using profile '%s' (GREASE=%v)", accountID, profile.Name, profile.EnableGREASE)
|
||||
|
||||
// 获取或创建带 TLS 指纹的客户端
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
|
||||
if err != nil {
|
||||
debugLog("[TLS Fingerprint] Account %d: Failed to acquire TLS client: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
debugLog("[TLS Fingerprint] Account %d: Request FAILED: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
debugLog("[TLS Fingerprint] Account %d: Request SUCCESS, status=%d", accountID, resp.StatusCode)
|
||||
|
||||
// 包装响应体,在关闭时自动减少计数并更新时间戳
|
||||
resp.Body = wrapTrackedBody(resp.Body, func() {
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
}
|
||||
|
||||
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
|
||||
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
isolation := s.getIsolationMode()
|
||||
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
|
||||
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
|
||||
// 读锁快速路径
|
||||
s.mu.RLock()
|
||||
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
debugLog("[TLS Fingerprint] Account %d: Reusing existing TLS client (cacheKey=%s)", accountID, cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 写锁慢路径
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.clients[cacheKey]; ok {
|
||||
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
debugLog("[TLS Fingerprint] Account %d: Reusing existing TLS client (cacheKey=%s)", accountID, cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
debugLog("[TLS Fingerprint] Account %d: Evicting stale TLS client (cacheKey=%s, proxyChanged=%v, poolChanged=%v)",
|
||||
accountID, cacheKey, entry.proxyKey != proxyKey, entry.poolKey != poolKey)
|
||||
s.removeClientLocked(cacheKey, entry)
|
||||
}
|
||||
|
||||
// 超出缓存上限时尝试淘汰
|
||||
if enforceLimit && s.maxUpstreamClients() > 0 {
|
||||
s.evictIdleLocked(now)
|
||||
if len(s.clients) >= s.maxUpstreamClients() {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
s.mu.Unlock()
|
||||
return nil, errUpstreamClientLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
debugLog("[TLS Fingerprint] Account %d: Creating NEW TLS fingerprint client (cacheKey=%s, proxy=%s)",
|
||||
accountID, cacheKey, proxyKey)
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.StoreInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.clients[cacheKey] = entry
|
||||
|
||||
s.evictIdleLocked(now)
|
||||
s.evictOverLimitLocked()
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
@@ -618,6 +792,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
|
||||
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
// - profile: TLS 指纹配置
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 配置错误
|
||||
//
|
||||
// 代理类型处理:
|
||||
// - nil/空: 直连,使用 TLSFingerprintDialer
|
||||
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
|
||||
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
|
||||
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
|
||||
ForceAttemptHTTP2: false,
|
||||
}
|
||||
|
||||
// 根据代理类型选择合适的 TLS 指纹 Dialer
|
||||
if proxyURL == nil {
|
||||
// 直连:使用 TLSFingerprintDialer
|
||||
debugLog("[TLS Fingerprint Transport] Using DIRECT TLS dialer (no proxy)")
|
||||
dialer := tlsfingerprint.NewDialer(profile, nil)
|
||||
transport.DialTLSContext = dialer.DialTLSContext
|
||||
} else {
|
||||
scheme := strings.ToLower(proxyURL.Scheme)
|
||||
switch scheme {
|
||||
case "socks5", "socks5h":
|
||||
// SOCKS5 代理:使用 SOCKS5ProxyDialer
|
||||
debugLog("[TLS Fingerprint Transport] Using SOCKS5 TLS dialer (proxy=%s)", proxyURL.Host)
|
||||
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = socks5Dialer.DialTLSContext
|
||||
case "http", "https":
|
||||
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
|
||||
debugLog("[TLS Fingerprint Transport] Using HTTP CONNECT TLS dialer (proxy=%s)", proxyURL.Host)
|
||||
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = httpDialer.DialTLSContext
|
||||
default:
|
||||
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
|
||||
debugLog("[TLS Fingerprint Transport] WARNING: Unknown proxy scheme '%s', falling back to standard proxy (NO TLS fingerprint)", scheme)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
// 在 Close 时执行回调,用于更新请求计数
|
||||
type trackedBody struct {
|
||||
|
||||
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
|
||||
// View filter: errors vs excluded vs all.
|
||||
// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
|
||||
// Excluded = business-limited errors (quota/concurrency/billing).
|
||||
// Upstream 429/529 are included in errors view to match SLA calculation.
|
||||
view := ""
|
||||
if filter != nil {
|
||||
view = strings.ToLower(strings.TrimSpace(filter.View))
|
||||
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
switch view {
|
||||
case "", "errors":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
|
||||
case "excluded":
|
||||
clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
|
||||
case "all":
|
||||
// no-op
|
||||
default:
|
||||
// treat unknown as default 'errors'
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
|
||||
}
|
||||
if len(filter.StatusCodes) > 0 {
|
||||
args = append(args, pq.Array(filter.StatusCodes))
|
||||
|
||||
82
backend/internal/repository/simple_mode_default_groups.go
Normal file
82
backend/internal/repository/simple_mode_default_groups.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
|
||||
requiredByPlatform := map[string]int{
|
||||
service.PlatformAnthropic: 1,
|
||||
service.PlatformOpenAI: 1,
|
||||
service.PlatformGemini: 1,
|
||||
service.PlatformAntigravity: 2,
|
||||
}
|
||||
|
||||
for platform, minCount := range requiredByPlatform {
|
||||
count, err := client.Group.Query().
|
||||
Where(group.PlatformEQ(platform), group.DeletedAtIsNil()).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count groups for platform %s: %w", platform, err)
|
||||
}
|
||||
|
||||
if platform == service.PlatformAntigravity {
|
||||
if count < minCount {
|
||||
for i := count; i < minCount; i++ {
|
||||
name := fmt.Sprintf("%s-default-%d", platform, i+1)
|
||||
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-antigravity platforms: ensure <platform>-default exists.
|
||||
name := platform + "-default"
|
||||
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error {
|
||||
exists, err := client.Group.Query().
|
||||
Where(group.NameEQ(name), group.DeletedAtIsNil()).
|
||||
Exist(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check group exists %s: %w", name, err)
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = client.Group.Create().
|
||||
SetName(name).
|
||||
SetDescription("Auto-created default group").
|
||||
SetPlatform(platform).
|
||||
SetStatus(service.StatusActive).
|
||||
SetSubscriptionType(service.SubscriptionTypeStandard).
|
||||
SetRateMultiplier(1.0).
|
||||
SetIsExclusive(false).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsConstraintError(err) {
|
||||
// Concurrent server startups may race on creation; treat as success.
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("create default group %s: %w", name, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
client := tx.Client()
|
||||
|
||||
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
|
||||
|
||||
assertGroupExists := func(name string) {
|
||||
exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "expected group %s to exist", name)
|
||||
}
|
||||
|
||||
assertGroupExists(service.PlatformAnthropic + "-default")
|
||||
assertGroupExists(service.PlatformOpenAI + "-default")
|
||||
assertGroupExists(service.PlatformGemini + "-default")
|
||||
assertGroupExists(service.PlatformAntigravity + "-default-1")
|
||||
assertGroupExists(service.PlatformAntigravity + "-default-2")
|
||||
}
|
||||
|
||||
func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
client := tx.Client()
|
||||
|
||||
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create and then soft-delete an anthropic default group.
|
||||
g, err := client.Group.Create().
|
||||
SetName(service.PlatformAnthropic + "-default").
|
||||
SetPlatform(service.PlatformAnthropic).
|
||||
SetStatus(service.StatusActive).
|
||||
SetSubscriptionType(service.SubscriptionTypeStandard).
|
||||
SetRateMultiplier(1.0).
|
||||
SetIsExclusive(false).
|
||||
Save(seedCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
|
||||
|
||||
// New active one should exist.
|
||||
count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
client := tx.Client()
|
||||
|
||||
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
|
||||
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
|
||||
|
||||
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
|
||||
|
||||
count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, count, 2)
|
||||
}
|
||||
Reference in New Issue
Block a user