merge: 合并官方 upstream/main 的 6 个功能更新
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
CI / test (pull_request) Has been cancelled
CI / golangci-lint (pull_request) Has been cancelled

合并内容:
1. feat(gateway): Claude Code 系统提示词智能注入
2. fix: 修复创建账号 schedulable 默认值为 false 的 bug
3. fix(frontend): 修复跨时区日期范围筛选问题
4. feat(proxy): SOCKS5H 代理支持(统一代理配置)
5. fix(oauth): 修复 Claude Cookie 添加账号时会话混淆
6. fix(test): 修复 OAuth 账号测试刷新 token 的 bug

新增文件:
- backend/internal/pkg/proxyutil/* (SOCKS5H 支持)
- backend/internal/service/gateway_prompt_test.go (测试)

来自 upstream: Wei-Shaw/sub2api commits d9b1587..a527559
This commit is contained in:
huangzhenpc
2026-01-04 18:48:05 +08:00
18 changed files with 10346 additions and 9797 deletions

View File

@@ -1,248 +1,248 @@
// Code generated by Wire. DO NOT EDIT. // Code generated by Wire. DO NOT EDIT.
//go:generate go run -mod=mod github.com/google/wire/cmd/wire //go:generate go run -mod=mod github.com/google/wire/cmd/wire
//go:build !wireinject //go:build !wireinject
// +build !wireinject // +build !wireinject
package main package main
import ( import (
"context" "context"
"github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/handler/admin" "github.com/Wei-Shaw/sub2api/internal/handler/admin"
"github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server"
"github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"log" "log"
"net/http" "net/http"
"time" "time"
) )
import ( import (
_ "embed" _ "embed"
_ "github.com/Wei-Shaw/sub2api/ent/runtime" _ "github.com/Wei-Shaw/sub2api/ent/runtime"
) )
// Injectors from wire.go: // Injectors from wire.go:
func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
configConfig, err := config.ProvideConfig() configConfig, err := config.ProvideConfig()
if err != nil { if err != nil {
return nil, err return nil, err
} }
client, err := repository.ProvideEnt(configConfig) client, err := repository.ProvideEnt(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db, err := repository.ProvideSQLDB(client) db, err := repository.ProvideSQLDB(client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userRepository := repository.NewUserRepository(client, db) userRepository := repository.NewUserRepository(client, db)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
settingService := service.NewSettingService(settingRepository, configConfig) settingService := service.NewSettingService(settingRepository, configConfig)
redisClient := repository.ProvideRedis(configConfig) redisClient := repository.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
userService := service.NewUserService(userRepository) userService := service.NewUserService(userRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService) authHandler := handler.NewAuthHandler(configConfig, authService, userService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService)
apiKeyRepository := repository.NewApiKeyRepository(client) apiKeyRepository := repository.NewApiKeyRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyCache := repository.NewApiKeyCache(redisClient) apiKeyCache := repository.NewApiKeyCache(redisClient)
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository) usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client) redeemCodeRepository := repository.NewRedeemCodeRepository(client)
billingCache := repository.NewBillingCache(redisClient) billingCache := repository.NewBillingCache(redisClient)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
dashboardService := service.NewDashboardService(usageLogRepository) dashboardService := service.NewDashboardService(usageLogRepository)
dashboardHandler := admin.NewDashboardHandler(dashboardService) dashboardHandler := admin.NewDashboardHandler(dashboardService)
accountRepository := repository.NewAccountRepository(client, db) accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber() proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService) rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher() claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache() usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient) gatewayCache := repository.NewGatewayCache(redisClient)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig) httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
oAuthHandler := admin.NewOAuthHandler(oAuthService) oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService) antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService)
updateCache := repository.NewUpdateCache(redisClient) updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.NewGitHubReleaseClient() gitHubReleaseClient := repository.NewGitHubReleaseClient()
serviceBuildInfo := provideServiceBuildInfo(buildInfo) serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService) systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler)
pricingRemoteClient := repository.NewPricingRemoteClient() pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
billingService := service.NewBillingService(configConfig, pricingService) billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(redisClient) identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
timingWheelService := service.ProvideTimingWheelService() timingWheelService := service.ProvideTimingWheelService()
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler)
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) apiKeyAuthMiddleware := middleware.NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService) engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService)
httpServer := server.ProvideHTTPServer(configConfig, engine) httpServer := server.ProvideHTTPServer(configConfig, engine)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) v := provideCleanup(client, redisClient, tokenRefreshService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{ application := &Application{
Server: httpServer, Server: httpServer,
Cleanup: v, Cleanup: v,
} }
return application, nil return application, nil
} }
// wire.go: // wire.go:
type Application struct { type Application struct {
Server *http.Server Server *http.Server
Cleanup func() Cleanup func()
} }
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{ return service.BuildInfo{
Version: buildInfo.Version, Version: buildInfo.Version,
BuildType: buildInfo.BuildType, BuildType: buildInfo.BuildType,
} }
} }
func provideCleanup( func provideCleanup(
entClient *ent.Client, entClient *ent.Client,
rdb *redis.Client, rdb *redis.Client,
tokenRefresh *service.TokenRefreshService, tokenRefresh *service.TokenRefreshService,
pricing *service.PricingService, pricing *service.PricingService,
emailQueue *service.EmailQueueService, emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService, billingCache *service.BillingCacheService,
oauth *service.OAuthService, oauth *service.OAuthService,
openaiOAuth *service.OpenAIOAuthService, openaiOAuth *service.OpenAIOAuthService,
geminiOAuth *service.GeminiOAuthService, geminiOAuth *service.GeminiOAuthService,
antigravityOAuth *service.AntigravityOAuthService, antigravityOAuth *service.AntigravityOAuthService,
) func() { ) func() {
return func() { return func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
cleanupSteps := []struct { cleanupSteps := []struct {
name string name string
fn func() error fn func() error
}{ }{
{"TokenRefreshService", func() error { {"TokenRefreshService", func() error {
tokenRefresh.Stop() tokenRefresh.Stop()
return nil return nil
}}, }},
{"PricingService", func() error { {"PricingService", func() error {
pricing.Stop() pricing.Stop()
return nil return nil
}}, }},
{"EmailQueueService", func() error { {"EmailQueueService", func() error {
emailQueue.Stop() emailQueue.Stop()
return nil return nil
}}, }},
{"BillingCacheService", func() error { {"BillingCacheService", func() error {
billingCache.Stop() billingCache.Stop()
return nil return nil
}}, }},
{"OAuthService", func() error { {"OAuthService", func() error {
oauth.Stop() oauth.Stop()
return nil return nil
}}, }},
{"OpenAIOAuthService", func() error { {"OpenAIOAuthService", func() error {
openaiOAuth.Stop() openaiOAuth.Stop()
return nil return nil
}}, }},
{"GeminiOAuthService", func() error { {"GeminiOAuthService", func() error {
geminiOAuth.Stop() geminiOAuth.Stop()
return nil return nil
}}, }},
{"AntigravityOAuthService", func() error { {"AntigravityOAuthService", func() error {
antigravityOAuth.Stop() antigravityOAuth.Stop()
return nil return nil
}}, }},
{"Redis", func() error { {"Redis", func() error {
return rdb.Close() return rdb.Close()
}}, }},
{"Ent", func() error { {"Ent", func() error {
return entClient.Close() return entClient.Close()
}}, }},
} }
for _, step := range cleanupSteps { for _, step := range cleanupSteps {
if err := step.fn(); err != nil { if err := step.fn(); err != nil {
log.Printf("[Cleanup] %s failed: %v", step.name, err) log.Printf("[Cleanup] %s failed: %v", step.name, err)
} else { } else {
log.Printf("[Cleanup] %s succeeded", step.name) log.Printf("[Cleanup] %s succeeded", step.name)
} }
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds")
default: default:
log.Printf("[Cleanup] All cleanup steps completed") log.Printf("[Cleanup] All cleanup steps completed")
} }
} }
} }

View File

@@ -1,323 +1,323 @@
package admin package admin
import ( import (
"strconv" "strconv"
"strings" "strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// ProxyHandler handles admin proxy management // ProxyHandler handles admin proxy management
type ProxyHandler struct { type ProxyHandler struct {
adminService service.AdminService adminService service.AdminService
} }
// NewProxyHandler creates a new admin proxy handler // NewProxyHandler creates a new admin proxy handler
func NewProxyHandler(adminService service.AdminService) *ProxyHandler { func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
return &ProxyHandler{ return &ProxyHandler{
adminService: adminService, adminService: adminService,
} }
} }
// CreateProxyRequest represents create proxy request // CreateProxyRequest represents create proxy request
type CreateProxyRequest struct { type CreateProxyRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
Host string `json:"host" binding:"required"` Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"` Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
} }
// UpdateProxyRequest represents update proxy request // UpdateProxyRequest represents update proxy request
type UpdateProxyRequest struct { type UpdateProxyRequest struct {
Name string `json:"name"` Name string `json:"name"`
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"` Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
Host string `json:"host"` Host string `json:"host"`
Port int `json:"port" binding:"omitempty,min=1,max=65535"` Port int `json:"port" binding:"omitempty,min=1,max=65535"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"`
} }
// List handles listing all proxies with pagination // List handles listing all proxies with pagination
// GET /api/v1/admin/proxies // GET /api/v1/admin/proxies
func (h *ProxyHandler) List(c *gin.Context) { func (h *ProxyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
protocol := c.Query("protocol") protocol := c.Query("protocol")
status := c.Query("status") status := c.Query("status")
search := c.Query("search") search := c.Query("search")
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
out := make([]dto.Proxy, 0, len(proxies)) out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i])) out = append(out, *dto.ProxyFromService(&proxies[i]))
} }
response.Paginated(c, out, total, page, pageSize) response.Paginated(c, out, total, page, pageSize)
} }
// GetAll handles getting all active proxies without pagination // GetAll handles getting all active proxies without pagination
// GET /api/v1/admin/proxies/all // GET /api/v1/admin/proxies/all
// Optional query param: with_count=true to include account count per proxy // Optional query param: with_count=true to include account count per proxy
func (h *ProxyHandler) GetAll(c *gin.Context) { func (h *ProxyHandler) GetAll(c *gin.Context) {
withCount := c.Query("with_count") == "true" withCount := c.Query("with_count") == "true"
if withCount { if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context()) proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
} }
response.Success(c, out) response.Success(c, out)
return return
} }
proxies, err := h.adminService.GetAllProxies(c.Request.Context()) proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
out := make([]dto.Proxy, 0, len(proxies)) out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies { for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i])) out = append(out, *dto.ProxyFromService(&proxies[i]))
} }
response.Success(c, out) response.Success(c, out)
} }
// GetByID handles getting a proxy by ID // GetByID handles getting a proxy by ID
// GET /api/v1/admin/proxies/:id // GET /api/v1/admin/proxies/:id
func (h *ProxyHandler) GetByID(c *gin.Context) { func (h *ProxyHandler) GetByID(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid proxy ID") response.BadRequest(c, "Invalid proxy ID")
return return
} }
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID) proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, dto.ProxyFromService(proxy)) response.Success(c, dto.ProxyFromService(proxy))
} }
// Create handles creating a new proxy // Create handles creating a new proxy
// POST /api/v1/admin/proxies // POST /api/v1/admin/proxies
func (h *ProxyHandler) Create(c *gin.Context) { func (h *ProxyHandler) Create(c *gin.Context) {
var req CreateProxyRequest var req CreateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name), Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol), Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host), Host: strings.TrimSpace(req.Host),
Port: req.Port, Port: req.Port,
Username: strings.TrimSpace(req.Username), Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password), Password: strings.TrimSpace(req.Password),
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, dto.ProxyFromService(proxy)) response.Success(c, dto.ProxyFromService(proxy))
} }
// Update handles updating a proxy // Update handles updating a proxy
// PUT /api/v1/admin/proxies/:id // PUT /api/v1/admin/proxies/:id
func (h *ProxyHandler) Update(c *gin.Context) { func (h *ProxyHandler) Update(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid proxy ID") response.BadRequest(c, "Invalid proxy ID")
return return
} }
var req UpdateProxyRequest var req UpdateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{ proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: strings.TrimSpace(req.Name), Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol), Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host), Host: strings.TrimSpace(req.Host),
Port: req.Port, Port: req.Port,
Username: strings.TrimSpace(req.Username), Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password), Password: strings.TrimSpace(req.Password),
Status: strings.TrimSpace(req.Status), Status: strings.TrimSpace(req.Status),
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, dto.ProxyFromService(proxy)) response.Success(c, dto.ProxyFromService(proxy))
} }
// Delete handles deleting a proxy // Delete handles deleting a proxy
// DELETE /api/v1/admin/proxies/:id // DELETE /api/v1/admin/proxies/:id
func (h *ProxyHandler) Delete(c *gin.Context) { func (h *ProxyHandler) Delete(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid proxy ID") response.BadRequest(c, "Invalid proxy ID")
return return
} }
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID) err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, gin.H{"message": "Proxy deleted successfully"}) response.Success(c, gin.H{"message": "Proxy deleted successfully"})
} }
// Test handles testing proxy connectivity // Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test // POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) { func (h *ProxyHandler) Test(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid proxy ID") response.BadRequest(c, "Invalid proxy ID")
return return
} }
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID) result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
response.Success(c, result) response.Success(c, result)
} }
// GetStats handles getting proxy statistics // GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats // GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) { func (h *ProxyHandler) GetStats(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid proxy ID") response.BadRequest(c, "Invalid proxy ID")
return return
} }
// Return mock data for now // Return mock data for now
_ = proxyID _ = proxyID
response.Success(c, gin.H{ response.Success(c, gin.H{
"total_accounts": 0, "total_accounts": 0,
"active_accounts": 0, "active_accounts": 0,
"total_requests": 0, "total_requests": 0,
"success_rate": 100.0, "success_rate": 100.0,
"average_latency": 0, "average_latency": 0,
}) })
} }
// GetProxyAccounts handles getting accounts using a proxy // GetProxyAccounts handles getting accounts using a proxy
// GET /api/v1/admin/proxies/:id/accounts // GET /api/v1/admin/proxies/:id/accounts
func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil { if err != nil {
response.BadRequest(c, "Invalid proxy ID") response.BadRequest(c, "Invalid proxy ID")
return return
} }
page, pageSize := response.ParsePagination(c) page, pageSize := response.ParsePagination(c)
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize) accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
out := make([]dto.Account, 0, len(accounts)) out := make([]dto.Account, 0, len(accounts))
for i := range accounts { for i := range accounts {
out = append(out, *dto.AccountFromService(&accounts[i])) out = append(out, *dto.AccountFromService(&accounts[i]))
} }
response.Paginated(c, out, total, page, pageSize) response.Paginated(c, out, total, page, pageSize)
} }
// BatchCreateProxyItem represents a single proxy in batch create request // BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct { type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
Host string `json:"host" binding:"required"` Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"` Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
} }
// BatchCreateRequest represents batch create proxies request // BatchCreateRequest represents batch create proxies request
type BatchCreateRequest struct { type BatchCreateRequest struct {
Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"` Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"`
} }
// BatchCreate handles batch creating proxies // BatchCreate handles batch creating proxies
// POST /api/v1/admin/proxies/batch // POST /api/v1/admin/proxies/batch
func (h *ProxyHandler) BatchCreate(c *gin.Context) { func (h *ProxyHandler) BatchCreate(c *gin.Context) {
var req BatchCreateRequest var req BatchCreateRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error()) response.BadRequest(c, "Invalid request: "+err.Error())
return return
} }
created := 0 created := 0
skipped := 0 skipped := 0
for _, item := range req.Proxies { for _, item := range req.Proxies {
// Trim all string fields // Trim all string fields
host := strings.TrimSpace(item.Host) host := strings.TrimSpace(item.Host)
protocol := strings.TrimSpace(item.Protocol) protocol := strings.TrimSpace(item.Protocol)
username := strings.TrimSpace(item.Username) username := strings.TrimSpace(item.Username)
password := strings.TrimSpace(item.Password) password := strings.TrimSpace(item.Password)
// Check for duplicates (same host, port, username, password) // Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password) exists, err := h.adminService.CheckProxyExists(c.Request.Context(), host, item.Port, username, password)
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
return return
} }
if exists { if exists {
skipped++ skipped++
continue continue
} }
// Create proxy with default name // Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ _, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default", Name: "default",
Protocol: protocol, Protocol: protocol,
Host: host, Host: host,
Port: item.Port, Port: item.Port,
Username: username, Username: username,
Password: password, Password: password,
}) })
if err != nil { if err != nil {
// If creation fails due to duplicate, count as skipped // If creation fails due to duplicate, count as skipped
skipped++ skipped++
continue continue
} }
created++ created++
} }
response.Success(c, gin.H{ response.Success(c, gin.H{
"created": created, "created": created,
"skipped": skipped, "skipped": skipped,
}) })
} }

View File

@@ -1,157 +1,138 @@
// Package httpclient 提供共享 HTTP 客户端池 // Package httpclient 提供共享 HTTP 客户端池
// //
// 性能优化说明: // 性能优化说明:
// 原实现在多个服务中重复创建 http.Client // 原实现在多个服务中重复创建 http.Client
// 1. proxy_probe_service.go: 每次探测创建新客户端 // 1. proxy_probe_service.go: 每次探测创建新客户端
// 2. pricing_service.go: 每次请求创建新客户端 // 2. pricing_service.go: 每次请求创建新客户端
// 3. turnstile_service.go: 每次验证创建新客户端 // 3. turnstile_service.go: 每次验证创建新客户端
// 4. github_release_service.go: 每次请求创建新客户端 // 4. github_release_service.go: 每次请求创建新客户端
// 5. claude_usage_service.go: 每次请求创建新客户端 // 5. claude_usage_service.go: 每次请求创建新客户端
// //
// 新实现使用统一的客户端池: // 新实现使用统一的客户端池:
// 1. 相同配置复用同一 http.Client 实例 // 1. 相同配置复用同一 http.Client 实例
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 // 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
// 3. 支持 HTTP/HTTPS/SOCKS5 代理 // 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理
// 4. 支持严格代理模式(代理失败则返回错误 // 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险
package httpclient package httpclient
import ( import (
"context" "crypto/tls"
"crypto/tls" "fmt"
"fmt" "net/http"
"net" "net/url"
"net/http" "strings"
"net/url" "sync"
"strings" "time"
"sync"
"time" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
"golang.org/x/net/proxy"
) // Transport 连接池默认配置
const (
// Transport 连接池默认配置 defaultMaxIdleConns = 100 // 最大空闲连接数
const ( defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultMaxIdleConns = 100 // 最大空闲连接 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 )
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间
) // Options 定义共享 HTTP 客户端的构建参数
type Options struct {
// Options 定义共享 HTTP 客户端的构建参数 ProxyURL string // 代理 URL支持 http/https/socks5/socks5h
type Options struct { Timeout time.Duration // 请求总超时时间
ProxyURL string // 代理 URL支持 http/https/socks5 ResponseHeaderTimeout time.Duration // 等待响应头超时时间
Timeout time.Duration // 请求总超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证 // 可选的连接池参数(不设置则使用默认值)
ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 MaxIdleConns int // 最大空闲连接总数(默认 100
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
// 可选的连接池参数(不设置则使用默认值 MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制
MaxIdleConns int // 最大空闲连接总数(默认 100 }
MaxIdleConnsPerHost int // 每主机最大空闲连接(默认 10
MaxConnsPerHost int // 每主机最大连接数(默认 0 无限制) // sharedClients 存储按配置参数缓存的 http.Client 实例
} var sharedClients sync.Map
// sharedClients 存储按配置参数缓存的 http.Client 实例 // GetClient 返回共享的 HTTP 客户端实例
var sharedClients sync.Map // 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
// GetClient 返回共享的 HTTP 客户端实例 func GetClient(opts Options) (*http.Client, error) {
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport key := buildClientKey(opts)
func GetClient(opts Options) (*http.Client, error) { if cached, ok := sharedClients.Load(key); ok {
key := buildClientKey(opts) if client, ok := cached.(*http.Client); ok {
if cached, ok := sharedClients.Load(key); ok { return client, nil
if client, ok := cached.(*http.Client); ok { }
return client, nil }
}
} client, err := buildClient(opts)
if err != nil {
client, err := buildClient(opts) return nil, err
if err != nil { }
if opts.ProxyStrict {
return nil, err actual, _ := sharedClients.LoadOrStore(key, client)
} if c, ok := actual.(*http.Client); ok {
fallback := opts return c, nil
fallback.ProxyURL = "" }
client, _ = buildClient(fallback) return client, nil
} }
actual, _ := sharedClients.LoadOrStore(key, client) func buildClient(opts Options) (*http.Client, error) {
if c, ok := actual.(*http.Client); ok { transport, err := buildTransport(opts)
return c, nil if err != nil {
} return nil, err
return client, nil }
}
return &http.Client{
func buildClient(opts Options) (*http.Client, error) { Transport: transport,
transport, err := buildTransport(opts) Timeout: opts.Timeout,
if err != nil { }, nil
return nil, err }
}
func buildTransport(opts Options) (*http.Transport, error) {
return &http.Client{ // 使用自定义值或默认值
Transport: transport, maxIdleConns := opts.MaxIdleConns
Timeout: opts.Timeout, if maxIdleConns <= 0 {
}, nil maxIdleConns = defaultMaxIdleConns
} }
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost
func buildTransport(opts Options) (*http.Transport, error) { if maxIdleConnsPerHost <= 0 {
// 使用自定义值或默认值 maxIdleConnsPerHost = defaultMaxIdleConnsPerHost
maxIdleConns := opts.MaxIdleConns }
if maxIdleConns <= 0 {
maxIdleConns = defaultMaxIdleConns transport := &http.Transport{
} MaxIdleConns: maxIdleConns,
maxIdleConnsPerHost := opts.MaxIdleConnsPerHost MaxIdleConnsPerHost: maxIdleConnsPerHost,
if maxIdleConnsPerHost <= 0 { MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制
maxIdleConnsPerHost = defaultMaxIdleConnsPerHost IdleConnTimeout: defaultIdleConnTimeout,
} ResponseHeaderTimeout: opts.ResponseHeaderTimeout,
}
transport := &http.Transport{
MaxIdleConns: maxIdleConns, if opts.InsecureSkipVerify {
MaxIdleConnsPerHost: maxIdleConnsPerHost, transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
MaxConnsPerHost: opts.MaxConnsPerHost, // 0 表示无限制 }
IdleConnTimeout: defaultIdleConnTimeout,
ResponseHeaderTimeout: opts.ResponseHeaderTimeout, proxyURL := strings.TrimSpace(opts.ProxyURL)
} if proxyURL == "" {
return transport, nil
if opts.InsecureSkipVerify { }
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
} parsed, err := url.Parse(proxyURL)
if err != nil {
proxyURL := strings.TrimSpace(opts.ProxyURL) return nil, err
if proxyURL == "" { }
return transport, nil
} if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, err
parsed, err := url.Parse(proxyURL) }
if err != nil {
return nil, err return transport, nil
} }
switch strings.ToLower(parsed.Scheme) { func buildClientKey(opts Options) string {
case "http", "https": return fmt.Sprintf("%s|%s|%s|%t|%d|%d|%d",
transport.Proxy = http.ProxyURL(parsed) strings.TrimSpace(opts.ProxyURL),
case "socks5", "socks5h": opts.Timeout.String(),
dialer, err := proxy.FromURL(parsed, proxy.Direct) opts.ResponseHeaderTimeout.String(),
if err != nil { opts.InsecureSkipVerify,
return nil, err opts.MaxIdleConns,
} opts.MaxIdleConnsPerHost,
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { opts.MaxConnsPerHost,
return dialer.Dial(network, addr) )
} }
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
}
return transport, nil
}
func buildClientKey(opts Options) string {
return fmt.Sprintf("%s|%s|%s|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
opts.ProxyStrict,
opts.MaxIdleConns,
opts.MaxIdleConnsPerHost,
opts.MaxConnsPerHost,
)
}

View File

@@ -0,0 +1,62 @@
// Package proxyutil 提供统一的代理配置功能
//
// 支持的代理协议:
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS
package proxyutil
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
// ConfigureTransportProxy 根据代理 URL 配置 Transport
//
// 支持的协议:
// - http/https: 设置 transport.Proxy
// - socks5/socks5h: 设置 transport.DialContext由代理服务端解析 DNS
//
// 参数:
// - transport: 需要配置的 http.Transport
// - proxyURL: 代理地址nil 表示直连
//
// 返回:
// - error: 代理配置错误(协议不支持或 dialer 创建失败)
func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error {
if proxyURL == nil {
return nil
}
scheme := strings.ToLower(proxyURL.Scheme)
switch scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(proxyURL)
return nil
case "socks5", "socks5h":
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
return fmt.Errorf("create socks5 dialer: %w", err)
}
// 优先使用支持 context 的 DialContext以支持请求取消和超时
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
transport.DialContext = contextDialer.DialContext
} else {
// 回退路径:如果 dialer 不支持 ContextDialer则包装为简单的 DialContext
// 注意:此回退不支持请求取消和超时控制
transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
}
return nil
default:
return fmt.Errorf("unsupported proxy scheme: %s", scheme)
}
}

View File

@@ -0,0 +1,204 @@
package proxyutil
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConfigureTransportProxy_Nil(t *testing.T) {
transport := &http.Transport{}
err := ConfigureTransportProxy(transport, nil)
require.NoError(t, err)
assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy")
assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext")
}
func TestConfigureTransportProxy_HTTP(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy")
assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext")
}
func TestConfigureTransportProxy_HTTPS(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443")
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy")
assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext")
}
func TestConfigureTransportProxy_SOCKS5(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy")
assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext")
}
func TestConfigureTransportProxy_SOCKS5H(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse("socks5h://socks.example.com:1080")
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy")
assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext")
}
func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) {
testCases := []struct {
scheme string
useProxy bool // true = uses Transport.Proxy, false = uses DialContext
}{
{"HTTP://proxy.example.com:8080", true},
{"Http://proxy.example.com:8080", true},
{"HTTPS://proxy.example.com:8443", true},
{"Https://proxy.example.com:8443", true},
{"SOCKS5://socks.example.com:1080", false},
{"Socks5://socks.example.com:1080", false},
{"SOCKS5H://socks.example.com:1080", false},
{"Socks5h://socks.example.com:1080", false},
}
for _, tc := range testCases {
t.Run(tc.scheme, func(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse(tc.scheme)
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
if tc.useProxy {
assert.NotNil(t, transport.Proxy)
assert.Nil(t, transport.DialContext)
} else {
assert.Nil(t, transport.Proxy)
assert.NotNil(t, transport.DialContext)
}
})
}
}
func TestConfigureTransportProxy_Unsupported(t *testing.T) {
testCases := []string{
"ftp://ftp.example.com",
"file:///path/to/file",
"unknown://example.com",
}
for _, tc := range testCases {
t.Run(tc, func(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse(tc)
err := ConfigureTransportProxy(transport, proxyURL)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported proxy scheme")
})
}
}
func TestConfigureTransportProxy_WithAuth(t *testing.T) {
transport := &http.Transport{}
proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080")
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext")
}
func TestConfigureTransportProxy_EmptyScheme(t *testing.T) {
transport := &http.Transport{}
// 空 scheme 的 URL
proxyURL := &url.URL{Host: "proxy.example.com:8080"}
err := ConfigureTransportProxy(transport, proxyURL)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported proxy scheme")
}
func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) {
// 验证代理配置不会覆盖 Transport 的其他配置
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
}
proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
err := ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved")
assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved")
assert.NotNil(t, transport.DialContext, "DialContext should be set")
}
func TestConfigureTransportProxy_IPv6(t *testing.T) {
testCases := []struct {
name string
proxyURL string
}{
{"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"},
{"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"},
{"HTTP with IPv6", "http://[::1]:8080"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
transport := &http.Transport{}
proxyURL, err := url.Parse(tc.proxyURL)
require.NoError(t, err, "URL should be parseable")
err = ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
})
}
}
func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) {
testCases := []struct {
name string
proxyURL string
}{
// 密码包含 @ 符号URL 编码为 %40
{"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"},
// 密码包含 : 符号URL 编码为 %3A
{"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"},
// 密码包含 / 符号URL 编码为 %2F
{"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"},
// 复杂密码
{"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
transport := &http.Transport{}
proxyURL, err := url.Parse(tc.proxyURL)
require.NoError(t, err, "URL should be parseable")
err = ConfigureTransportProxy(transport, proxyURL)
require.NoError(t, err)
assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext")
})
}
}

View File

@@ -1,251 +1,257 @@
package repository package repository
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3" "github.com/imroc/req/v3"
) )
func NewClaudeOAuthClient() service.ClaudeOAuthClient { func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{ return &claudeOAuthService{
baseURL: "https://claude.ai", baseURL: "https://claude.ai",
tokenURL: oauth.TokenURL, tokenURL: oauth.TokenURL,
clientFactory: createReqClient, clientFactory: createReqClient,
} }
} }
type claudeOAuthService struct { type claudeOAuthService struct {
baseURL string baseURL string
tokenURL string tokenURL string
clientFactory func(proxyURL string) *req.Client clientFactory func(proxyURL string) *req.Client
} }
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
var orgs []struct { var orgs []struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
} }
targetURL := s.baseURL + "/api/organizations" targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R(). resp, err := client.R().
SetContext(ctx). SetContext(ctx).
SetCookies(&http.Cookie{ SetCookies(&http.Cookie{
Name: "sessionKey", Name: "sessionKey",
Value: sessionKey, Value: sessionKey,
}). }).
SetSuccessResult(&orgs). SetSuccessResult(&orgs).
Get(targetURL) Get(targetURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
} }
if len(orgs) == 0 { if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found") return "", fmt.Errorf("no organizations found")
} }
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID) log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil return orgs[0].UUID, nil
} }
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
reqBody := map[string]any{ reqBody := map[string]any{
"response_type": "code", "response_type": "code",
"client_id": oauth.ClientID, "client_id": oauth.ClientID,
"organization_uuid": orgUUID, "organization_uuid": orgUUID,
"redirect_uri": oauth.RedirectURI, "redirect_uri": oauth.RedirectURI,
"scope": scope, "scope": scope,
"state": state, "state": state,
"code_challenge": codeChallenge, "code_challenge": codeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
} }
reqBodyJSON, _ := json.Marshal(reqBody) reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct { var result struct {
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
} }
resp, err := client.R(). resp, err := client.R().
SetContext(ctx). SetContext(ctx).
SetCookies(&http.Cookie{ SetCookies(&http.Cookie{
Name: "sessionKey", Name: "sessionKey",
Value: sessionKey, Value: sessionKey,
}). }).
SetHeader("Accept", "application/json"). SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9"). SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache"). SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai"). SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new"). SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json"). SetHeader("Content-Type", "application/json").
SetBody(reqBody). SetBody(reqBody).
SetSuccessResult(&result). SetSuccessResult(&result).
Post(authURL) Post(authURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err) return "", fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
} }
if result.RedirectURI == "" { if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response") return "", fmt.Errorf("no redirect_uri in response")
} }
parsedURL, err := url.Parse(result.RedirectURI) parsedURL, err := url.Parse(result.RedirectURI)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err) return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
} }
queryParams := parsedURL.Query() queryParams := parsedURL.Query()
authCode := queryParams.Get("code") authCode := queryParams.Get("code")
responseState := queryParams.Get("state") responseState := queryParams.Get("state")
if authCode == "" { if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri") return "", fmt.Errorf("no authorization code in redirect_uri")
} }
fullCode := authCode fullCode := authCode
if responseState != "" { if responseState != "" {
fullCode = authCode + "#" + responseState fullCode = authCode + "#" + responseState
} }
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20)) log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
return fullCode, nil return fullCode, nil
} }
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state" // Parse code which may contain state in format "authCode#state"
authCode := code authCode := code
codeState := "" codeState := ""
if idx := strings.Index(code, "#"); idx != -1 { if idx := strings.Index(code, "#"); idx != -1 {
authCode = code[:idx] authCode = code[:idx]
codeState = code[idx+1:] codeState = code[idx+1:]
} }
reqBody := map[string]any{ reqBody := map[string]any{
"code": authCode, "code": authCode,
"grant_type": "authorization_code", "grant_type": "authorization_code",
"client_id": oauth.ClientID, "client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI, "redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier, "code_verifier": codeVerifier,
} }
if codeState != "" { if codeState != "" {
reqBody["state"] = codeState reqBody["state"] = codeState
} }
// Setup token requires longer expiration (1 year) // Setup token requires longer expiration (1 year)
if isSetupToken { if isSetupToken {
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
} }
reqBodyJSON, _ := json.Marshal(reqBody) reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse var tokenResp oauth.TokenResponse
resp, err := client.R(). resp, err := client.R().
SetContext(ctx). SetContext(ctx).
SetHeader("Content-Type", "application/json"). SetHeader("Content-Type", "application/json").
SetBody(reqBody). SetBody(reqBody).
SetSuccessResult(&tokenResp). SetSuccessResult(&tokenResp).
Post(s.tokenURL) Post(s.tokenURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
} }
log.Printf("[OAuth] Step 3 SUCCESS - Got access token") log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil return &tokenResp, nil
} }
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL) client := s.clientFactory(proxyURL)
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致) // 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
// Anthropic OAuth API 期望 JSON 格式的请求体 // Anthropic OAuth API 期望 JSON 格式的请求体
reqBody := map[string]any{ reqBody := map[string]any{
"grant_type": "refresh_token", "grant_type": "refresh_token",
"refresh_token": refreshToken, "refresh_token": refreshToken,
"client_id": oauth.ClientID, "client_id": oauth.ClientID,
} }
var tokenResp oauth.TokenResponse var tokenResp oauth.TokenResponse
resp, err := client.R(). resp, err := client.R().
SetContext(ctx). SetContext(ctx).
SetHeader("Content-Type", "application/json"). SetHeader("Content-Type", "application/json").
SetBody(reqBody). SetBody(reqBody).
SetSuccessResult(&tokenResp). SetSuccessResult(&tokenResp).
Post(s.tokenURL) Post(s.tokenURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("request failed: %w", err) return nil, fmt.Errorf("request failed: %w", err)
} }
if !resp.IsSuccessState() { if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String()) return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
} }
return &tokenResp, nil return &tokenResp, nil
} }
func createReqClient(proxyURL string) *req.Client { func createReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{ // 禁用 CookieJar确保每次授权都是干净的会话
ProxyURL: proxyURL, client := req.C().
Timeout: 60 * time.Second, SetTimeout(60 * time.Second).
Impersonate: true, ImpersonateChrome().
}) SetCookieJar(nil) // 禁用 CookieJar
}
if strings.TrimSpace(proxyURL) != "" {
func prefix(s string, n int) string { client.SetProxyURL(strings.TrimSpace(proxyURL))
if n <= 0 { }
return ""
} return client
if len(s) <= n { }
return s
} func prefix(s string, n int) string {
return s[:n] if n <= 0 {
} return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,66 +1,70 @@
package repository package repository
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"testing" "testing"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
) )
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作 // httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
// 这是 Go 基准测试的常见模式,确保测试结果准确 // 这是 Go 基准测试的常见模式,确保测试结果准确
var httpClientSink *http.Client var httpClientSink *http.Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销 // BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
// //
// 测试目的: // 测试目的:
// - 验证连接池复用相比每次新建的性能提升 // - 验证连接池复用相比每次新建的性能提升
// - 量化内存分配差异 // - 量化内存分配差异
// //
// 预期结果: // 预期结果:
// - "复用" 子测试应显著快于 "新建" // - "复用" 子测试应显著快于 "新建"
// - "复用" 子测试应零内存分配 // - "复用" 子测试应零内存分配
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 创建测试配置 // 创建测试配置
cfg := &config.Config{ cfg := &config.Config{
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300}, Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
} }
upstream := NewHTTPUpstream(cfg) upstream := NewHTTPUpstream(cfg)
svc, ok := upstream.(*httpUpstreamService) svc, ok := upstream.(*httpUpstreamService)
if !ok { if !ok {
b.Fatalf("类型断言失败,无法获取 httpUpstreamService") b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
} }
proxyURL := "http://127.0.0.1:8080" proxyURL := "http://127.0.0.1:8080"
b.ReportAllocs() // 报告内存分配统计 b.ReportAllocs() // 报告内存分配统计
// 子测试:每次新建客户端 // 子测试:每次新建客户端
// 模拟未优化前的行为,每次请求都创建新的 http.Client // 模拟未优化前的行为,每次请求都创建新的 http.Client
b.Run("新建", func(b *testing.B) { b.Run("新建", func(b *testing.B) {
parsedProxy, err := url.Parse(proxyURL) parsedProxy, err := url.Parse(proxyURL)
if err != nil { if err != nil {
b.Fatalf("解析代理地址失败: %v", err) b.Fatalf("解析代理地址失败: %v", err)
} }
settings := defaultPoolSettings(cfg) settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配 // 每次迭代都创建新客户端,包含 Transport 分配
httpClientSink = &http.Client{ transport, err := buildUpstreamTransport(settings, parsedProxy)
Transport: buildUpstreamTransport(settings, parsedProxy), if err != nil {
} b.Fatalf("创建 Transport 失败: %v", err)
} }
}) httpClientSink = &http.Client{
Transport: transport,
// 子测试:复用已缓存的客户端 }
// 模拟优化后的行为,从缓存获取客户端 }
b.Run("复用", func(b *testing.B) { })
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1) // 子测试:复用已缓存的客户端
client := entry.client // 模拟优化后的行为,从缓存获取客户端
b.ResetTimer() // 重置计时器,排除预热时间 b.Run("复用", func(b *testing.B) {
for i := 0; i < b.N; i++ { // 预热:确保客户端已缓存
// 直接使用缓存的客户端,无内存分配 entry := svc.getOrCreateClient(proxyURL, 1, 1)
httpClientSink = client client := entry.client
} b.ResetTimer() // 重置计时器,排除预热时间
}) for i := 0; i < b.N; i++ {
} // 直接使用缓存的客户端,无内存分配
httpClientSink = client
}
})
}

View File

@@ -1,76 +1,75 @@
package repository package repository
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
) )
func NewProxyExitInfoProber() service.ProxyExitInfoProber { func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{ipInfoURL: defaultIPInfoURL} return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
} }
const defaultIPInfoURL = "https://ipinfo.io/json" const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct { type proxyProbeService struct {
ipInfoURL string ipInfoURL string
} }
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{ client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL, ProxyURL: proxyURL,
Timeout: 15 * time.Second, Timeout: 15 * time.Second,
InsecureSkipVerify: true, InsecureSkipVerify: true,
ProxyStrict: true, })
}) if err != nil {
if err != nil { return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err) }
}
startTime := time.Now()
startTime := time.Now() req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil) if err != nil {
if err != nil { return nil, 0, fmt.Errorf("failed to create request: %w", err)
return nil, 0, fmt.Errorf("failed to create request: %w", err) }
}
resp, err := client.Do(req)
resp, err := client.Do(req) if err != nil {
if err != nil { return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
return nil, 0, fmt.Errorf("proxy connection failed: %w", err) }
} defer func() { _ = resp.Body.Close() }()
defer func() { _ = resp.Body.Close() }()
latencyMs := time.Since(startTime).Milliseconds()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
if resp.StatusCode != http.StatusOK { return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) }
}
var ipInfo struct {
var ipInfo struct { IP string `json:"ip"`
IP string `json:"ip"` City string `json:"city"`
City string `json:"city"` Region string `json:"region"`
Region string `json:"region"` Country string `json:"country"`
Country string `json:"country"` }
}
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body) if err != nil {
if err != nil { return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) }
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
if err := json.Unmarshal(body, &ipInfo); err != nil { return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) }
}
return &service.ProxyExitInfo{
return &service.ProxyExitInfo{ IP: ipInfo.IP,
IP: ipInfo.IP, City: ipInfo.City,
City: ipInfo.City, Region: ipInfo.Region,
Region: ipInfo.Region, Country: ipInfo.Country,
Country: ipInfo.Country, }, latencyMs, nil
}, latencyMs, nil }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,233 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsClaudeCodeClient(t *testing.T) {
tests := []struct {
name string
userAgent string
metadataUserID string
want bool
}{
{
name: "Claude Code client",
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
want: true,
},
{
name: "Claude Code without version suffix",
userAgent: "claude-cli/2.0.0",
metadataUserID: "session_abc",
want: true,
},
{
name: "Missing metadata user_id",
userAgent: "claude-cli/1.0.0",
metadataUserID: "",
want: false,
},
{
name: "Different user agent",
userAgent: "curl/7.68.0",
metadataUserID: "user123",
want: false,
},
{
name: "Empty user agent",
userAgent: "",
metadataUserID: "user123",
want: false,
},
{
name: "Similar but not Claude CLI",
userAgent: "claude-api/1.0.0",
metadataUserID: "user123",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID)
require.Equal(t, tt.want, got)
})
}
}
func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
tests := []struct {
name string
system any
want bool
}{
{
name: "nil system",
system: nil,
want: false,
},
{
name: "empty string",
system: "",
want: false,
},
{
name: "string with Claude Code prompt",
system: claudeCodeSystemPrompt,
want: true,
},
{
name: "string with different content",
system: "You are a helpful assistant.",
want: false,
},
{
name: "empty array",
system: []any{},
want: false,
},
{
name: "array with Claude Code prompt",
system: []any{
map[string]any{
"type": "text",
"text": claudeCodeSystemPrompt,
},
},
want: true,
},
{
name: "array with Claude Code prompt in second position",
system: []any{
map[string]any{"type": "text", "text": "First prompt"},
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
},
want: true,
},
{
name: "array without Claude Code prompt",
system: []any{
map[string]any{"type": "text", "text": "Custom prompt"},
},
want: false,
},
{
name: "array with partial match (should not match)",
system: []any{
map[string]any{"type": "text", "text": "You are Claude"},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := systemIncludesClaudeCodePrompt(tt.system)
require.Equal(t, tt.want, got)
})
}
}
func TestInjectClaudeCodePrompt(t *testing.T) {
tests := []struct {
name string
body string
system any
wantSystemLen int
wantFirstText string
wantSecondText string
}{
{
name: "nil system",
body: `{"model":"claude-3"}`,
system: nil,
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
{
name: "empty string system",
body: `{"model":"claude-3"}`,
system: "",
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
{
name: "string system",
body: `{"model":"claude-3"}`,
system: "Custom prompt",
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom prompt",
},
{
name: "string system equals Claude Code prompt",
body: `{"model":"claude-3"}`,
system: claudeCodeSystemPrompt,
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
{
name: "array system",
body: `{"model":"claude-3"}`,
system: []any{map[string]any{"type": "text", "text": "Custom"}},
// Claude Code + Custom = 2
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Custom",
},
{
name: "array system with existing Claude Code prompt (should dedupe)",
body: `{"model":"claude-3"}`,
system: []any{
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
map[string]any{"type": "text", "text": "Other"},
},
// Claude Code at start + Other = 2 (deduped)
wantSystemLen: 2,
wantFirstText: claudeCodeSystemPrompt,
wantSecondText: "Other",
},
{
name: "empty array",
body: `{"model":"claude-3"}`,
system: []any{},
wantSystemLen: 1,
wantFirstText: claudeCodeSystemPrompt,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := injectClaudeCodePrompt([]byte(tt.body), tt.system)
var parsed map[string]any
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
system, ok := parsed["system"].([]any)
require.True(t, ok, "system should be an array")
require.Len(t, system, tt.wantSystemLen)
first, ok := system[0].(map[string]any)
require.True(t, ok)
require.Equal(t, tt.wantFirstText, first["text"])
require.Equal(t, "text", first["type"])
// Check cache_control
cc, ok := first["cache_control"].(map[string]any)
require.True(t, ok)
require.Equal(t, "ephemeral", cc["type"])
if tt.wantSecondText != "" && len(system) > 1 {
second, ok := system[1].(map[string]any)
require.True(t, ok)
require.Equal(t, tt.wantSecondText, second["text"])
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,444 +1,452 @@
<template> <template>
<div class="relative" ref="containerRef"> <div class="relative" ref="containerRef">
<button <button
type="button" type="button"
@click="toggle" @click="toggle"
:class="['date-picker-trigger', isOpen && 'date-picker-trigger-open']" :class="['date-picker-trigger', isOpen && 'date-picker-trigger-open']"
> >
<span class="date-picker-icon"> <span class="date-picker-icon">
<svg <svg
class="h-4 w-4" class="h-4 w-4"
fill="none" fill="none"
stroke="currentColor" stroke="currentColor"
viewBox="0 0 24 24" viewBox="0 0 24 24"
stroke-width="1.5" stroke-width="1.5"
> >
<path <path
stroke-linecap="round" stroke-linecap="round"
stroke-linejoin="round" stroke-linejoin="round"
d="M6.75 3v2.25M17.25 3v2.25M3 18.75V7.5a2.25 2.25 0 012.25-2.25h13.5A2.25 2.25 0 0121 7.5v11.25m-18 0A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75m-18 0v-7.5A2.25 2.25 0 015.25 9h13.5A2.25 2.25 0 0121 11.25v7.5" d="M6.75 3v2.25M17.25 3v2.25M3 18.75V7.5a2.25 2.25 0 012.25-2.25h13.5A2.25 2.25 0 0121 7.5v11.25m-18 0A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75m-18 0v-7.5A2.25 2.25 0 015.25 9h13.5A2.25 2.25 0 0121 11.25v7.5"
/> />
</svg> </svg>
</span> </span>
<span class="date-picker-value"> <span class="date-picker-value">
{{ displayValue }} {{ displayValue }}
</span> </span>
<span class="date-picker-chevron"> <span class="date-picker-chevron">
<svg <svg
:class="['h-4 w-4 transition-transform duration-200', isOpen && 'rotate-180']" :class="['h-4 w-4 transition-transform duration-200', isOpen && 'rotate-180']"
fill="none" fill="none"
stroke="currentColor" stroke="currentColor"
viewBox="0 0 24 24" viewBox="0 0 24 24"
stroke-width="1.5" stroke-width="1.5"
> >
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /> <path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" />
</svg> </svg>
</span> </span>
</button> </button>
<Transition name="date-picker-dropdown"> <Transition name="date-picker-dropdown">
<div v-if="isOpen" class="date-picker-dropdown"> <div v-if="isOpen" class="date-picker-dropdown">
<!-- Quick presets --> <!-- Quick presets -->
<div class="date-picker-presets"> <div class="date-picker-presets">
<button <button
v-for="preset in presets" v-for="preset in presets"
:key="preset.value" :key="preset.value"
@click="selectPreset(preset)" @click="selectPreset(preset)"
:class="['date-picker-preset', isPresetActive(preset) && 'date-picker-preset-active']" :class="['date-picker-preset', isPresetActive(preset) && 'date-picker-preset-active']"
> >
{{ t(preset.labelKey) }} {{ t(preset.labelKey) }}
</button> </button>
</div> </div>
<div class="date-picker-divider"></div> <div class="date-picker-divider"></div>
<!-- Custom date range inputs --> <!-- Custom date range inputs -->
<div class="date-picker-custom"> <div class="date-picker-custom">
<div class="date-picker-field"> <div class="date-picker-field">
<label class="date-picker-label">{{ t('dates.startDate') }}</label> <label class="date-picker-label">{{ t('dates.startDate') }}</label>
<input <input
type="date" type="date"
v-model="localStartDate" v-model="localStartDate"
:max="localEndDate || today" :max="localEndDate || tomorrow"
class="date-picker-input" class="date-picker-input"
@change="onDateChange" @change="onDateChange"
/> />
</div> </div>
<div class="date-picker-separator"> <div class="date-picker-separator">
<svg <svg
class="h-4 w-4 text-gray-400" class="h-4 w-4 text-gray-400"
fill="none" fill="none"
stroke="currentColor" stroke="currentColor"
viewBox="0 0 24 24" viewBox="0 0 24 24"
stroke-width="1.5" stroke-width="1.5"
> >
<path <path
stroke-linecap="round" stroke-linecap="round"
stroke-linejoin="round" stroke-linejoin="round"
d="M17.25 8.25L21 12m0 0l-3.75 3.75M21 12H3" d="M17.25 8.25L21 12m0 0l-3.75 3.75M21 12H3"
/> />
</svg> </svg>
</div> </div>
<div class="date-picker-field"> <div class="date-picker-field">
<label class="date-picker-label">{{ t('dates.endDate') }}</label> <label class="date-picker-label">{{ t('dates.endDate') }}</label>
<input <input
type="date" type="date"
v-model="localEndDate" v-model="localEndDate"
:min="localStartDate" :min="localStartDate"
:max="today" :max="tomorrow"
class="date-picker-input" class="date-picker-input"
@change="onDateChange" @change="onDateChange"
/> />
</div> </div>
</div> </div>
<!-- Apply button --> <!-- Apply button -->
<div class="date-picker-actions"> <div class="date-picker-actions">
<button @click="apply" class="date-picker-apply"> <button @click="apply" class="date-picker-apply">
{{ t('dates.apply') }} {{ t('dates.apply') }}
</button> </button>
</div> </div>
</div> </div>
</Transition> </Transition>
</div> </div>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, watch, onMounted, onUnmounted } from 'vue' import { ref, computed, watch, onMounted, onUnmounted } from 'vue'
import { useI18n } from 'vue-i18n' import { useI18n } from 'vue-i18n'
interface DatePreset { interface DatePreset {
labelKey: string labelKey: string
value: string value: string
getRange: () => { start: string; end: string } getRange: () => { start: string; end: string }
} }
interface Props { interface Props {
startDate: string startDate: string
endDate: string endDate: string
} }
interface Emits { interface Emits {
(e: 'update:startDate', value: string): void (e: 'update:startDate', value: string): void
(e: 'update:endDate', value: string): void (e: 'update:endDate', value: string): void
(e: 'change', range: { startDate: string; endDate: string; preset: string | null }): void (e: 'change', range: { startDate: string; endDate: string; preset: string | null }): void
} }
const props = defineProps<Props>() const props = defineProps<Props>()
const emit = defineEmits<Emits>() const emit = defineEmits<Emits>()
const { t, locale } = useI18n() const { t, locale } = useI18n()
const isOpen = ref(false) const isOpen = ref(false)
const containerRef = ref<HTMLElement | null>(null) const containerRef = ref<HTMLElement | null>(null)
const localStartDate = ref(props.startDate) const localStartDate = ref(props.startDate)
const localEndDate = ref(props.endDate) const localEndDate = ref(props.endDate)
const activePreset = ref<string | null>('7days') const activePreset = ref<string | null>('7days')
const today = computed(() => { const today = computed(() => {
// Use local timezone to avoid UTC timezone issues // Use local timezone to avoid UTC timezone issues
const now = new Date() const now = new Date()
const year = now.getFullYear() const year = now.getFullYear()
const month = String(now.getMonth() + 1).padStart(2, '0') const month = String(now.getMonth() + 1).padStart(2, '0')
const day = String(now.getDate()).padStart(2, '0') const day = String(now.getDate()).padStart(2, '0')
return `${year}-${month}-${day}` return `${year}-${month}-${day}`
}) })
// Helper function to format date to YYYY-MM-DD using local timezone // Tomorrow's date - used for max date to handle timezone differences
const formatDateToString = (date: Date): string => { // When user is in a timezone behind the server, "today" on server might be "tomorrow" locally
const year = date.getFullYear() const tomorrow = computed(() => {
const month = String(date.getMonth() + 1).padStart(2, '0') const d = new Date()
const day = String(date.getDate()).padStart(2, '0') d.setDate(d.getDate() + 1)
return `${year}-${month}-${day}` return formatDateToString(d)
} })
const presets: DatePreset[] = [ // Helper function to format date to YYYY-MM-DD using local timezone
{ const formatDateToString = (date: Date): string => {
labelKey: 'dates.today', const year = date.getFullYear()
value: 'today', const month = String(date.getMonth() + 1).padStart(2, '0')
getRange: () => { const day = String(date.getDate()).padStart(2, '0')
const t = today.value return `${year}-${month}-${day}`
return { start: t, end: t } }
}
}, const presets: DatePreset[] = [
{ {
labelKey: 'dates.yesterday', labelKey: 'dates.today',
value: 'yesterday', value: 'today',
getRange: () => { getRange: () => {
const d = new Date() const t = today.value
d.setDate(d.getDate() - 1) return { start: t, end: t }
const yesterday = formatDateToString(d) }
return { start: yesterday, end: yesterday } },
} {
}, labelKey: 'dates.yesterday',
{ value: 'yesterday',
labelKey: 'dates.last7Days', getRange: () => {
value: '7days', const d = new Date()
getRange: () => { d.setDate(d.getDate() - 1)
const end = today.value const yesterday = formatDateToString(d)
const d = new Date() return { start: yesterday, end: yesterday }
d.setDate(d.getDate() - 6) }
const start = formatDateToString(d) },
return { start, end } {
} labelKey: 'dates.last7Days',
}, value: '7days',
{ getRange: () => {
labelKey: 'dates.last14Days', const end = today.value
value: '14days', const d = new Date()
getRange: () => { d.setDate(d.getDate() - 6)
const end = today.value const start = formatDateToString(d)
const d = new Date() return { start, end }
d.setDate(d.getDate() - 13) }
const start = formatDateToString(d) },
return { start, end } {
} labelKey: 'dates.last14Days',
}, value: '14days',
{ getRange: () => {
labelKey: 'dates.last30Days', const end = today.value
value: '30days', const d = new Date()
getRange: () => { d.setDate(d.getDate() - 13)
const end = today.value const start = formatDateToString(d)
const d = new Date() return { start, end }
d.setDate(d.getDate() - 29) }
const start = formatDateToString(d) },
return { start, end } {
} labelKey: 'dates.last30Days',
}, value: '30days',
{ getRange: () => {
labelKey: 'dates.thisMonth', const end = today.value
value: 'thisMonth', const d = new Date()
getRange: () => { d.setDate(d.getDate() - 29)
const now = new Date() const start = formatDateToString(d)
const start = formatDateToString(new Date(now.getFullYear(), now.getMonth(), 1)) return { start, end }
return { start, end: today.value } }
} },
}, {
{ labelKey: 'dates.thisMonth',
labelKey: 'dates.lastMonth', value: 'thisMonth',
value: 'lastMonth', getRange: () => {
getRange: () => { const now = new Date()
const now = new Date() const start = formatDateToString(new Date(now.getFullYear(), now.getMonth(), 1))
const start = formatDateToString(new Date(now.getFullYear(), now.getMonth() - 1, 1)) return { start, end: today.value }
const end = formatDateToString(new Date(now.getFullYear(), now.getMonth(), 0)) }
return { start, end } },
} {
} labelKey: 'dates.lastMonth',
] value: 'lastMonth',
getRange: () => {
const displayValue = computed(() => { const now = new Date()
if (activePreset.value) { const start = formatDateToString(new Date(now.getFullYear(), now.getMonth() - 1, 1))
const preset = presets.find((p) => p.value === activePreset.value) const end = formatDateToString(new Date(now.getFullYear(), now.getMonth(), 0))
if (preset) return t(preset.labelKey) return { start, end }
} }
}
if (localStartDate.value && localEndDate.value) { ]
if (localStartDate.value === localEndDate.value) {
return formatDate(localStartDate.value) const displayValue = computed(() => {
} if (activePreset.value) {
return `${formatDate(localStartDate.value)} - ${formatDate(localEndDate.value)}` const preset = presets.find((p) => p.value === activePreset.value)
} if (preset) return t(preset.labelKey)
}
return t('dates.selectDateRange')
}) if (localStartDate.value && localEndDate.value) {
if (localStartDate.value === localEndDate.value) {
const formatDate = (dateStr: string): string => { return formatDate(localStartDate.value)
const date = new Date(dateStr + 'T00:00:00') }
const dateLocale = locale.value === 'zh' ? 'zh-CN' : 'en-US' return `${formatDate(localStartDate.value)} - ${formatDate(localEndDate.value)}`
return date.toLocaleDateString(dateLocale, { month: 'short', day: 'numeric' }) }
}
return t('dates.selectDateRange')
const isPresetActive = (preset: DatePreset): boolean => { })
return activePreset.value === preset.value
} const formatDate = (dateStr: string): string => {
const date = new Date(dateStr + 'T00:00:00')
const selectPreset = (preset: DatePreset) => { const dateLocale = locale.value === 'zh' ? 'zh-CN' : 'en-US'
const range = preset.getRange() return date.toLocaleDateString(dateLocale, { month: 'short', day: 'numeric' })
localStartDate.value = range.start }
localEndDate.value = range.end
activePreset.value = preset.value const isPresetActive = (preset: DatePreset): boolean => {
} return activePreset.value === preset.value
}
const onDateChange = () => {
// Check if current dates match any preset const selectPreset = (preset: DatePreset) => {
activePreset.value = null const range = preset.getRange()
for (const preset of presets) { localStartDate.value = range.start
const range = preset.getRange() localEndDate.value = range.end
if (range.start === localStartDate.value && range.end === localEndDate.value) { activePreset.value = preset.value
activePreset.value = preset.value }
break
} const onDateChange = () => {
} // Check if current dates match any preset
} activePreset.value = null
for (const preset of presets) {
const toggle = () => { const range = preset.getRange()
isOpen.value = !isOpen.value if (range.start === localStartDate.value && range.end === localEndDate.value) {
} activePreset.value = preset.value
break
const apply = () => { }
emit('update:startDate', localStartDate.value) }
emit('update:endDate', localEndDate.value) }
emit('change', {
startDate: localStartDate.value, const toggle = () => {
endDate: localEndDate.value, isOpen.value = !isOpen.value
preset: activePreset.value }
})
isOpen.value = false const apply = () => {
} emit('update:startDate', localStartDate.value)
emit('update:endDate', localEndDate.value)
const handleClickOutside = (event: MouseEvent) => { emit('change', {
if (containerRef.value && !containerRef.value.contains(event.target as Node)) { startDate: localStartDate.value,
isOpen.value = false endDate: localEndDate.value,
} preset: activePreset.value
} })
isOpen.value = false
const handleEscape = (event: KeyboardEvent) => { }
if (event.key === 'Escape' && isOpen.value) {
isOpen.value = false const handleClickOutside = (event: MouseEvent) => {
} if (containerRef.value && !containerRef.value.contains(event.target as Node)) {
} isOpen.value = false
}
// Sync local state with props }
watch(
() => props.startDate, const handleEscape = (event: KeyboardEvent) => {
(val) => { if (event.key === 'Escape' && isOpen.value) {
localStartDate.value = val isOpen.value = false
onDateChange() }
} }
)
// Sync local state with props
watch( watch(
() => props.endDate, () => props.startDate,
(val) => { (val) => {
localEndDate.value = val localStartDate.value = val
onDateChange() onDateChange()
} }
) )
onMounted(() => { watch(
document.addEventListener('click', handleClickOutside) () => props.endDate,
document.addEventListener('keydown', handleEscape) (val) => {
// Initialize active preset detection localEndDate.value = val
onDateChange() onDateChange()
}) }
)
onUnmounted(() => {
document.removeEventListener('click', handleClickOutside) onMounted(() => {
document.removeEventListener('keydown', handleEscape) document.addEventListener('click', handleClickOutside)
}) document.addEventListener('keydown', handleEscape)
</script> // Initialize active preset detection
onDateChange()
<style scoped> })
.date-picker-trigger {
@apply flex items-center gap-2; onUnmounted(() => {
@apply rounded-lg px-3 py-2 text-sm; document.removeEventListener('click', handleClickOutside)
@apply bg-white dark:bg-dark-800; document.removeEventListener('keydown', handleEscape)
@apply border border-gray-200 dark:border-dark-600; })
@apply text-gray-700 dark:text-gray-300; </script>
@apply transition-all duration-200;
@apply focus:border-primary-500 focus:outline-none focus:ring-2 focus:ring-primary-500/30; <style scoped>
@apply hover:border-gray-300 dark:hover:border-dark-500; .date-picker-trigger {
@apply cursor-pointer; @apply flex items-center gap-2;
} @apply rounded-lg px-3 py-2 text-sm;
@apply bg-white dark:bg-dark-800;
.date-picker-trigger-open { @apply border border-gray-200 dark:border-dark-600;
@apply border-primary-500 ring-2 ring-primary-500/30; @apply text-gray-700 dark:text-gray-300;
} @apply transition-all duration-200;
@apply focus:border-primary-500 focus:outline-none focus:ring-2 focus:ring-primary-500/30;
.date-picker-icon { @apply hover:border-gray-300 dark:hover:border-dark-500;
@apply text-gray-400 dark:text-dark-400; @apply cursor-pointer;
} }
.date-picker-value { .date-picker-trigger-open {
@apply font-medium; @apply border-primary-500 ring-2 ring-primary-500/30;
} }
.date-picker-chevron { .date-picker-icon {
@apply text-gray-400 dark:text-dark-400; @apply text-gray-400 dark:text-dark-400;
} }
.date-picker-dropdown { .date-picker-value {
@apply absolute left-0 z-[100] mt-2; @apply font-medium;
@apply bg-white dark:bg-dark-800; }
@apply rounded-xl;
@apply border border-gray-200 dark:border-dark-700; .date-picker-chevron {
@apply shadow-lg shadow-black/10 dark:shadow-black/30; @apply text-gray-400 dark:text-dark-400;
@apply overflow-hidden; }
@apply min-w-[320px];
} .date-picker-dropdown {
@apply absolute left-0 z-[100] mt-2;
.date-picker-presets { @apply bg-white dark:bg-dark-800;
@apply grid grid-cols-2 gap-1 p-2; @apply rounded-xl;
} @apply border border-gray-200 dark:border-dark-700;
@apply shadow-lg shadow-black/10 dark:shadow-black/30;
.date-picker-preset { @apply overflow-hidden;
@apply rounded-md px-3 py-1.5 text-xs font-medium; @apply min-w-[320px];
@apply text-gray-600 dark:text-gray-400; }
@apply hover:bg-gray-100 dark:hover:bg-dark-700;
@apply transition-colors duration-150; .date-picker-presets {
} @apply grid grid-cols-2 gap-1 p-2;
}
.date-picker-preset-active {
@apply bg-primary-100 dark:bg-primary-900/30; .date-picker-preset {
@apply text-primary-700 dark:text-primary-300; @apply rounded-md px-3 py-1.5 text-xs font-medium;
} @apply text-gray-600 dark:text-gray-400;
@apply hover:bg-gray-100 dark:hover:bg-dark-700;
.date-picker-divider { @apply transition-colors duration-150;
@apply border-t border-gray-100 dark:border-dark-700; }
}
.date-picker-preset-active {
.date-picker-custom { @apply bg-primary-100 dark:bg-primary-900/30;
@apply flex items-end gap-2 p-3; @apply text-primary-700 dark:text-primary-300;
} }
.date-picker-field { .date-picker-divider {
@apply flex-1; @apply border-t border-gray-100 dark:border-dark-700;
} }
.date-picker-label { .date-picker-custom {
@apply mb-1 block text-xs font-medium text-gray-500 dark:text-gray-400; @apply flex items-end gap-2 p-3;
} }
.date-picker-input { .date-picker-field {
@apply w-full rounded-md px-2 py-1.5 text-sm; @apply flex-1;
@apply bg-gray-50 dark:bg-dark-700; }
@apply border border-gray-200 dark:border-dark-600;
@apply text-gray-900 dark:text-gray-100; .date-picker-label {
@apply focus:border-primary-500 focus:outline-none focus:ring-2 focus:ring-primary-500/30; @apply mb-1 block text-xs font-medium text-gray-500 dark:text-gray-400;
} }
.date-picker-input::-webkit-calendar-picker-indicator { .date-picker-input {
@apply cursor-pointer opacity-60 hover:opacity-100; @apply w-full rounded-md px-2 py-1.5 text-sm;
filter: invert(0.5); @apply bg-gray-50 dark:bg-dark-700;
} @apply border border-gray-200 dark:border-dark-600;
@apply text-gray-900 dark:text-gray-100;
.dark .date-picker-input::-webkit-calendar-picker-indicator { @apply focus:border-primary-500 focus:outline-none focus:ring-2 focus:ring-primary-500/30;
filter: invert(0.7); }
}
.date-picker-input::-webkit-calendar-picker-indicator {
.date-picker-separator { @apply cursor-pointer opacity-60 hover:opacity-100;
@apply flex items-center justify-center pb-1; filter: invert(0.5);
} }
.date-picker-actions { .dark .date-picker-input::-webkit-calendar-picker-indicator {
@apply flex justify-end p-2 pt-0; filter: invert(0.7);
} }
.date-picker-apply { .date-picker-separator {
@apply rounded-lg px-4 py-1.5 text-sm font-medium; @apply flex items-center justify-center pb-1;
@apply bg-primary-600 text-white; }
@apply hover:bg-primary-700;
@apply transition-colors duration-150; .date-picker-actions {
} @apply flex justify-end p-2 pt-0;
}
/* Dropdown animation */
.date-picker-dropdown-enter-active, .date-picker-apply {
.date-picker-dropdown-leave-active { @apply rounded-lg px-4 py-1.5 text-sm font-medium;
transition: all 0.2s ease; @apply bg-primary-600 text-white;
} @apply hover:bg-primary-700;
@apply transition-colors duration-150;
.date-picker-dropdown-enter-from, }
.date-picker-dropdown-leave-to {
opacity: 0; /* Dropdown animation */
transform: translateY(-8px); .date-picker-dropdown-enter-active,
} .date-picker-dropdown-leave-active {
</style> transition: all 0.2s ease;
}
.date-picker-dropdown-enter-from,
.date-picker-dropdown-leave-to {
opacity: 0;
transform: translateY(-8px);
}
</style>

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff