feat(websearch): proxy failover, timeout, quota-weighted load balancing

- Use proxyutil.ConfigureTransportProxy for unified proxy protocol support
  (HTTP/HTTPS/SOCKS5/SOCKS5H), replacing ad-hoc HTTP-only proxy code
- Proxy errors return ErrProxyUnavailable → gateway triggers account switch
  via UpstreamFailoverError instead of fallback to direct connection
- Timeout: proxy dial 3s, TLS handshake 3s, data transfer 60s
- Mark proxy unavailable for 5 minutes in Redis on connectivity failure
- Quota-weighted load balancing: providers with quota_limit>0 are selected
  by remaining quota (weighted random); quota_limit=0 providers treated as
  0% weight and placed last
This commit is contained in:
erio
2026-04-12 01:48:06 +08:00
parent 7535e312e0
commit fda61b067c
2 changed files with 227 additions and 52 deletions

View File

@@ -3,8 +3,11 @@ package websearch
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math/rand"
"net"
"net/http" "net/http"
"net/url" "net/url"
"sort" "sort"
@@ -12,6 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
@@ -30,6 +34,7 @@ type ProviderConfig struct {
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly QuotaRefreshInterval string `json:"quota_refresh_interval"` // QuotaRefreshDaily / Weekly / Monthly
ProxyURL string `json:"-"` // resolved proxy URL (not persisted) ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds) ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
} }
@@ -42,22 +47,30 @@ type Manager struct {
clientCache map[string]*http.Client clientCache map[string]*http.Client
} }
// Timeout constants for proxy and search operations.
const ( const (
proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
searchDataTimeout = 60 * time.Second // response data transfer timeout
searchRequestTimeout = searchDataTimeout + proxyDialTimeout
quotaKeyPrefix = "websearch:quota:" quotaKeyPrefix = "websearch:quota:"
searchRequestTimeout = 30 * time.Second proxyUnavailableKey = "websearch:proxy_unavailable:%d"
proxyUnavailableTTL = 5 * time.Minute
quotaTTLBuffer = 24 * time.Hour quotaTTLBuffer = 24 * time.Hour
maxCachedClients = 100 maxCachedClients = 100
) )
// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
// Callers may use this to trigger account switching instead of direct fallback.
var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
// quotaIncrScript atomically increments the counter and sets TTL on first creation. // quotaIncrScript atomically increments the counter and sets TTL on first creation.
// KEYS[1] = quota key, ARGV[1] = TTL in seconds.
// Returns the new counter value.
var quotaIncrScript = redis.NewScript(` var quotaIncrScript = redis.NewScript(`
local val = redis.call('INCR', KEYS[1]) local val = redis.call('INCR', KEYS[1])
if val == 1 then if val == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[1]) redis.call('EXPIRE', KEYS[1], ARGV[1])
else else
-- Defensive: ensure TTL exists even if a prior EXPIRE failed
local ttl = redis.call('TTL', KEYS[1]) local ttl = redis.call('TTL', KEYS[1])
if ttl == -1 then if ttl == -1 then
redis.call('EXPIRE', KEYS[1], ARGV[1]) redis.call('EXPIRE', KEYS[1], ARGV[1])
@@ -80,16 +93,22 @@ func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
} }
} }
// SearchWithBestProvider selects the highest-priority available provider, // SearchWithBestProvider selects a provider using quota-weighted load balancing,
// reserves quota, executes the search, and rolls back quota on failure. // reserves quota, executes the search, and rolls back quota on failure.
// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) { func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
if strings.TrimSpace(req.Query) == "" { if strings.TrimSpace(req.Query) == "" {
return nil, "", fmt.Errorf("websearch: empty search query") return nil, "", fmt.Errorf("websearch: empty search query")
} }
for _, cfg := range m.configs {
if !m.isProviderAvailable(cfg) { candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
continue if len(candidates) == 0 {
} return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
}
selected := m.selectByQuotaWeight(ctx, candidates)
for _, cfg := range selected {
allowed, incremented := m.tryReserveQuota(ctx, cfg) allowed, incremented := m.tryReserveQuota(ctx, cfg)
if !allowed { if !allowed {
continue continue
@@ -99,6 +118,12 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest)
if incremented { if incremented {
m.rollbackQuota(ctx, cfg) m.rollbackQuota(ctx, cfg)
} }
if isProxyError(err) {
m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
slog.Warn("websearch: proxy error, marking unavailable",
"provider", cfg.Type, "error", err)
return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
}
slog.Warn("websearch: provider search failed", slog.Warn("websearch: provider search failed",
"provider", cfg.Type, "error", err) "provider", cfg.Type, "error", err)
continue continue
@@ -108,6 +133,76 @@ func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest)
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)") return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
} }
// filterAvailableProviders returns providers that have API keys, are not expired,
// and whose proxies are not marked unavailable.
func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
var out []ProviderConfig
for _, cfg := range m.configs {
if !m.isProviderAvailable(cfg) {
continue
}
proxyID := resolveProxyID(cfg, accountProxyURL)
if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
slog.Debug("websearch: proxy marked unavailable, skipping",
"provider", cfg.Type, "proxy_id", proxyID)
continue
}
out = append(out, cfg)
}
return out
}
// selectByQuotaWeight orders candidates by remaining quota weight.
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
// Among providers with quota, higher remaining quota = higher priority.
func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
type weighted struct {
cfg ProviderConfig
weight int64
}
items := make([]weighted, 0, len(candidates))
for _, cfg := range candidates {
w := int64(0)
if cfg.QuotaLimit > 0 {
used, _ := m.GetUsage(ctx, cfg.Type, cfg.QuotaRefreshInterval)
remaining := cfg.QuotaLimit - used
if remaining > 0 {
w = remaining
}
}
items = append(items, weighted{cfg: cfg, weight: w})
}
// Separate providers with quota (weight > 0) from those without (weight == 0)
var withQuota, withoutQuota []weighted
for _, item := range items {
if item.weight > 0 {
withQuota = append(withQuota, item)
} else {
withoutQuota = append(withoutQuota, item)
}
}
// Within quota group: weighted random sort (higher remaining = more likely first)
if len(withQuota) > 1 {
sort.Slice(withQuota, func(i, j int) bool {
wi := float64(withQuota[i].weight) * (0.5 + rand.Float64())
wj := float64(withQuota[j].weight) * (0.5 + rand.Float64())
return wi > wj
})
}
// Build final order: quota providers first, then no-quota providers (original priority order)
result := make([]ProviderConfig, 0, len(candidates))
for _, item := range withQuota {
result = append(result, item.cfg)
}
for _, item := range withoutQuota {
result = append(result, item.cfg)
}
return result
}
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool { func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
if cfg.APIKey == "" { if cfg.APIKey == "" {
return false return false
@@ -120,26 +215,80 @@ func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
return true return true
} }
// tryReserveQuota atomically increments the counter via Lua script and checks limit. // --- Proxy availability tracking ---
// Returns (allowed, incremented): allowed=true means the request may proceed;
// incremented=true means the Redis counter was actually incremented (so rollback is needed on failure). // markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
proxyID := resolveProxyID(cfg, accountProxyURL)
if proxyID <= 0 || m.redis == nil {
return
}
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
slog.Warn("websearch: failed to mark proxy unavailable",
"proxy_id", proxyID, "error", err)
}
}
// isProxyAvailable checks whether a proxy is currently marked as unavailable.
func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
if m.redis == nil || proxyID <= 0 {
return true
}
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
val, err := m.redis.Get(ctx, key).Result()
if err != nil {
return true // Redis error → assume available
}
return val == ""
}
// resolveProxyID determines the effective proxy ID for a provider+account combination.
func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
if accountProxyURL != "" {
return 0 // account proxy has no ID in provider config
}
return cfg.ProxyID
}
// isProxyError checks whether the error is likely caused by proxy connectivity.
func isProxyError(err error) bool {
if err == nil {
return false
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) {
return true
}
msg := err.Error()
return strings.Contains(msg, "proxy") ||
strings.Contains(msg, "SOCKS") ||
strings.Contains(msg, "connection refused") ||
strings.Contains(msg, "no such host") ||
strings.Contains(msg, "i/o timeout")
}
// --- Quota management ---
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) { func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
if cfg.QuotaLimit <= 0 { if cfg.QuotaLimit <= 0 {
return true, false // unlimited, no INCR return true, false
} }
if m.redis == nil { if m.redis == nil {
slog.Warn("websearch: Redis unavailable, quota check skipped", slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
"provider", cfg.Type) return true, false
return true, false // allowed but not incremented
} }
key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval) key := quotaRedisKey(cfg.Type, cfg.QuotaRefreshInterval)
ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds()) ttlSec := int(quotaTTL(cfg.QuotaRefreshInterval).Seconds())
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64() newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
if err != nil { if err != nil {
slog.Warn("websearch: quota Lua INCR failed, allowing request", slog.Warn("websearch: quota Lua INCR failed, allowing request",
"provider", cfg.Type, "error", err) "provider", cfg.Type, "error", err)
return true, false // allowed but not incremented return true, false
} }
if newVal > cfg.QuotaLimit { if newVal > cfg.QuotaLimit {
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil { if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
@@ -148,12 +297,11 @@ func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool
} }
slog.Info("websearch: provider quota exhausted", slog.Info("websearch: provider quota exhausted",
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit) "provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
return false, false // rejected, already rolled back return false, false
} }
return true, true // allowed and incremented return true, true
} }
// rollbackQuota decrements the counter after a search failure.
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) { func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
if cfg.QuotaLimit <= 0 || m.redis == nil { if cfg.QuotaLimit <= 0 || m.redis == nil {
return return
@@ -165,16 +313,64 @@ func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
} }
} }
// --- Search execution ---
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) { func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
proxyURL := cfg.ProxyURL proxyURL := cfg.ProxyURL
if req.ProxyURL != "" { if req.ProxyURL != "" {
proxyURL = req.ProxyURL proxyURL = req.ProxyURL
} }
client := m.getOrCreateHTTPClient(proxyURL) client, err := m.getOrCreateHTTPClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("websearch: %w", err)
}
provider := m.buildProvider(cfg, client) provider := m.buildProvider(cfg, client)
return provider.Search(ctx, req) return provider.Search(ctx, req)
} }
// --- HTTP client cache ---
func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
m.clientMu.Lock()
defer m.clientMu.Unlock()
if c, ok := m.clientCache[proxyURL]; ok {
return c, nil
}
if len(m.clientCache) >= maxCachedClients {
m.clientCache = make(map[string]*http.Client)
}
c, err := newHTTPClient(proxyURL)
if err != nil {
return nil, err
}
m.clientCache[proxyURL] = c
return c, nil
}
// newHTTPClient creates an HTTP client with proper timeout settings.
// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
// (HTTP/HTTPS/SOCKS5/SOCKS5H).
// Returns error if proxyURL is invalid — never falls back to direct connection.
func newHTTPClient(proxyURL string) (*http.Client, error) {
transport := &http.Transport{
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext,
TLSHandshakeTimeout: proxyTLSTimeout,
ResponseHeaderTimeout: searchDataTimeout,
}
if proxyURL != "" {
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, fmt.Errorf("configure proxy: %w", err)
}
}
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
}
// GetUsage returns the current usage count for the given provider. // GetUsage returns the current usage count for the given provider.
func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) { func (m *Manager) GetUsage(ctx context.Context, providerType, refreshInterval string) (int64, error) {
if m.redis == nil { if m.redis == nil {
@@ -198,35 +394,6 @@ func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
return result return result
} }
// --- HTTP client cache (bounded) ---
func (m *Manager) getOrCreateHTTPClient(proxyURL string) *http.Client {
m.clientMu.Lock()
defer m.clientMu.Unlock()
if c, ok := m.clientCache[proxyURL]; ok {
return c
}
if len(m.clientCache) >= maxCachedClients {
m.clientCache = make(map[string]*http.Client) // evict all
}
c := newHTTPClient(proxyURL)
m.clientCache[proxyURL] = c
return c
}
func newHTTPClient(proxyURL string) *http.Client {
transport := &http.Transport{
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
}
if proxyURL != "" {
if u, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(u)
}
}
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}
}
// --- Provider factory --- // --- Provider factory ---
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider { func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
@@ -256,7 +423,7 @@ func periodKey(refreshInterval string) string {
case QuotaRefreshWeekly: case QuotaRefreshWeekly:
year, week := now.ISOWeek() year, week := now.ISOWeek()
return fmt.Sprintf("%d-W%02d", year, week) return fmt.Sprintf("%d-W%02d", year, week)
default: // QuotaRefreshMonthly default:
return now.Format("2006-01") return now.Format("2006-01")
} }
} }

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"net/http" "net/http"
@@ -147,6 +148,13 @@ func (s *GatewayService) handleWebSearchEmulation(
resp, providerName, err := doWebSearch(ctx, account, query) resp, providerName, err := doWebSearch(ctx, account, query)
if err != nil { if err != nil {
// Proxy unavailable → trigger account switch via UpstreamFailoverError
if errors.Is(err, websearch.ErrProxyUnavailable) {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(err.Error()),
}
}
return nil, err return nil, err
} }