From cceada7daeffb12e1b9e45ba87d2618c2c661751 Mon Sep 17 00:00:00 2001 From: Forest Date: Sat, 20 Dec 2025 11:56:11 +0800 Subject: [PATCH] refactor(backend): service http ports --- backend/cmd/server/wire.go | 10 + backend/cmd/server/wire_gen.go | 34 ++- .../internal/handler/admin/system_handler.go | 7 +- backend/internal/handler/wire.go | 7 +- .../repository/claude_oauth_service.go | 235 ++++++++++++++++++ backend/internal/repository/claude_service.go | 64 +++++ .../repository/claude_usage_service.go | 59 +++++ .../repository/github_release_service.go | 116 +++++++++ .../internal/repository/pricing_service.go | 73 ++++++ .../repository/proxy_probe_service.go | 104 ++++++++ .../internal/repository/turnstile_service.go | 55 ++++ backend/internal/repository/wire.go | 9 + .../internal/service/account_test_service.go | 35 +-- .../internal/service/account_usage_service.go | 62 +---- backend/internal/service/admin_service.go | 131 ++-------- backend/internal/service/gateway_service.go | 137 +++------- backend/internal/service/oauth_service.go | 219 ++-------------- backend/internal/service/pricing_service.go | 100 +++----- backend/internal/service/service.go | 1 + backend/internal/service/turnstile_service.go | 43 +--- backend/internal/service/update_service.go | 104 ++------ backend/internal/service/wire.go | 17 +- 22 files changed, 927 insertions(+), 695 deletions(-) create mode 100644 backend/internal/repository/claude_oauth_service.go create mode 100644 backend/internal/repository/claude_service.go create mode 100644 backend/internal/repository/claude_usage_service.go create mode 100644 backend/internal/repository/github_release_service.go create mode 100644 backend/internal/repository/pricing_service.go create mode 100644 backend/internal/repository/proxy_probe_service.go create mode 100644 backend/internal/repository/turnstile_service.go diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 9e9f8677..2fe3f468 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -40,6 +40,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { // 服务器层 ProviderSet server.ProviderSet, + // BuildInfo provider + provideServiceBuildInfo, + // 清理函数提供者 provideCleanup, @@ -49,6 +52,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { return nil, nil } +func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { + return service.BuildInfo{ + Version: buildInfo.Version, + BuildType: buildInfo.BuildType, + } +} + func provideCleanup( db *gorm.DB, rdb *redis.Client, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 5f13e3cc..347e831f 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -43,7 +43,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { client := infrastructure.ProvideRedis(configConfig) emailCache := repository.NewEmailCache(client) emailService := service.NewEmailService(settingRepository, emailCache) - turnstileService := service.NewTurnstileService(settingService) + turnstileVerifier := repository.NewTurnstileVerifier() + turnstileService := service.NewTurnstileService(settingService, turnstileVerifier) emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authHandler := handler.NewAuthHandler(authService) @@ -68,32 +69,41 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) accountRepository := repository.NewAccountRepository(db) proxyRepository := repository.NewProxyRepository(db) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService) + proxyExitInfoProber := repository.NewProxyExitInfoProber() + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber) dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) - oAuthService := service.NewOAuthService(proxyRepository) + claudeOAuthClient := repository.NewClaudeOAuthClient() + oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) rateLimitService := service.NewRateLimitService(accountRepository, configConfig) - accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService) - accountTestService := service.NewAccountTestService(accountRepository, oAuthService) + claudeUsageFetcher := repository.NewClaudeUsageFetcher() + accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher) + claudeUpstream := repository.NewClaudeUpstream(configConfig) + accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream) accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService) oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService) proxyHandler := admin.NewProxyHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService) settingHandler := admin.NewSettingHandler(settingService, emailService) - systemHandler := handler.ProvideSystemHandler(client, buildInfo) + updateCache := repository.NewUpdateCache(client) + gitHubReleaseClient := repository.NewGitHubReleaseClient() + serviceBuildInfo := provideServiceBuildInfo(buildInfo) + updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) + systemHandler := handler.ProvideSystemHandler(updateService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) gatewayCache := repository.NewGatewayCache(client) - pricingService, err := service.ProvidePricingService(configConfig) + pricingRemoteClient := repository.NewPricingRemoteClient() + pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) if err != nil { return nil, err } billingService := service.NewBillingService(configConfig, pricingService) identityCache := repository.NewIdentityCache(client) identityService := service.NewIdentityService(identityCache) - gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) + gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream) concurrencyCache := repository.NewConcurrencyCache(client) concurrencyService := service.NewConcurrencyService(concurrencyCache) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) @@ -127,6 +137,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { Subscription: subscriptionService, Concurrency: concurrencyService, Identity: identityService, + Update: updateService, } repositories := &repository.Repositories{ User: userRepository, @@ -156,6 +167,13 @@ type Application struct { Cleanup func() } +func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { + return service.BuildInfo{ + Version: buildInfo.Version, + BuildType: buildInfo.BuildType, + } +} + func provideCleanup( db *gorm.DB, rdb *redis.Client, diff --git a/backend/internal/handler/admin/system_handler.go b/backend/internal/handler/admin/system_handler.go index 4e5f40ba..cc833f57 100644 --- a/backend/internal/handler/admin/system_handler.go +++ b/backend/internal/handler/admin/system_handler.go @@ -6,11 +6,9 @@ import ( "sub2api/internal/pkg/response" "sub2api/internal/pkg/sysutil" - "sub2api/internal/repository" "sub2api/internal/service" "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" ) // SystemHandler handles system-related operations @@ -19,10 +17,9 @@ type SystemHandler struct { } // NewSystemHandler creates a new SystemHandler -func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler { - updateCache := repository.NewUpdateCache(rdb) +func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { return &SystemHandler{ - updateSvc: service.NewUpdateService(updateCache, version, buildType), + updateSvc: updateSvc, } } diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 02e182e8..435fbd0c 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -5,7 +5,6 @@ import ( "sub2api/internal/service" "github.com/google/wire" - "github.com/redis/go-redis/v9" ) // ProvideAdminHandlers creates the AdminHandlers struct @@ -37,9 +36,9 @@ func ProvideAdminHandlers( } } -// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters -func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler { - return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType) +// ProvideSystemHandler creates admin.SystemHandler with UpdateService +func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { + return admin.NewSystemHandler(updateService) } // ProvideSettingHandler creates SettingHandler with version from BuildInfo diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go new file mode 100644 index 00000000..ab0dee88 --- /dev/null +++ b/backend/internal/repository/claude_oauth_service.go @@ -0,0 +1,235 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + "time" + + "sub2api/internal/pkg/oauth" + "sub2api/internal/service" + + "github.com/imroc/req/v3" +) + +type claudeOAuthService struct{} + +func NewClaudeOAuthClient() service.ClaudeOAuthClient { + return &claudeOAuthService{} +} + +func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { + client := createReqClient(proxyURL) + + var orgs []struct { + UUID string `json:"uuid"` + } + + targetURL := "https://claude.ai/api/organizations" + log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) + + resp, err := client.R(). + SetContext(ctx). + SetCookies(&http.Cookie{ + Name: "sessionKey", + Value: sessionKey, + }). + SetSuccessResult(&orgs). + Get(targetURL) + + if err != nil { + log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) + return "", fmt.Errorf("request failed: %w", err) + } + + log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + + if !resp.IsSuccessState() { + return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) + } + + if len(orgs) == 0 { + return "", fmt.Errorf("no organizations found") + } + + log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID) + return orgs[0].UUID, nil +} + +func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { + client := createReqClient(proxyURL) + + authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID) + + reqBody := map[string]interface{}{ + "response_type": "code", + "client_id": oauth.ClientID, + "organization_uuid": orgUUID, + "redirect_uri": oauth.RedirectURI, + "scope": scope, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + } + + reqBodyJSON, _ := json.Marshal(reqBody) + log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) + log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) + + var result struct { + RedirectURI string `json:"redirect_uri"` + } + + resp, err := client.R(). + SetContext(ctx). + SetCookies(&http.Cookie{ + Name: "sessionKey", + Value: sessionKey, + }). + SetHeader("Accept", "application/json"). + SetHeader("Accept-Language", "en-US,en;q=0.9"). + SetHeader("Cache-Control", "no-cache"). + SetHeader("Origin", "https://claude.ai"). + SetHeader("Referer", "https://claude.ai/new"). + SetHeader("Content-Type", "application/json"). + SetBody(reqBody). + SetSuccessResult(&result). + Post(authURL) + + if err != nil { + log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) + return "", fmt.Errorf("request failed: %w", err) + } + + log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + + if !resp.IsSuccessState() { + return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) + } + + if result.RedirectURI == "" { + return "", fmt.Errorf("no redirect_uri in response") + } + + parsedURL, err := url.Parse(result.RedirectURI) + if err != nil { + return "", fmt.Errorf("failed to parse redirect_uri: %w", err) + } + + queryParams := parsedURL.Query() + authCode := queryParams.Get("code") + responseState := queryParams.Get("state") + + if authCode == "" { + return "", fmt.Errorf("no authorization code in redirect_uri") + } + + fullCode := authCode + if responseState != "" { + fullCode = authCode + "#" + responseState + } + + log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20]) + return fullCode, nil +} + +func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) { + client := createReqClient(proxyURL) + + authCode := code + codeState := "" + if len(code) > 0 { + parts := make([]string, 0, 2) + for i, part := range []rune(code) { + if part == '#' { + authCode = code[:i] + codeState = code[i+1:] + break + } + } + if len(parts) == 0 { + authCode = code + } + } + + reqBody := map[string]interface{}{ + "code": authCode, + "grant_type": "authorization_code", + "client_id": oauth.ClientID, + "redirect_uri": oauth.RedirectURI, + "code_verifier": codeVerifier, + } + + if codeState != "" { + reqBody["state"] = codeState + } + + reqBodyJSON, _ := json.Marshal(reqBody) + log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL) + log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) + + var tokenResp oauth.TokenResponse + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Content-Type", "application/json"). + SetBody(reqBody). + SetSuccessResult(&tokenResp). + Post(oauth.TokenURL) + + if err != nil { + log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) + return nil, fmt.Errorf("request failed: %w", err) + } + + log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) + + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) + } + + log.Printf("[OAuth] Step 3 SUCCESS - Got access token") + return &tokenResp, nil +} + +func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + client := createReqClient(proxyURL) + + formData := url.Values{} + formData.Set("grant_type", "refresh_token") + formData.Set("refresh_token", refreshToken) + formData.Set("client_id", oauth.ClientID) + + var tokenResp oauth.TokenResponse + + resp, err := client.R(). + SetContext(ctx). + SetFormDataFromValues(formData). + SetSuccessResult(&tokenResp). + Post(oauth.TokenURL) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String()) + } + + return &tokenResp, nil +} + +func createReqClient(proxyURL string) *req.Client { + client := req.C(). + ImpersonateChrome(). + SetTimeout(60 * time.Second) + + if proxyURL != "" { + client.SetProxyURL(proxyURL) + } + + return client +} diff --git a/backend/internal/repository/claude_service.go b/backend/internal/repository/claude_service.go new file mode 100644 index 00000000..dad8730e --- /dev/null +++ b/backend/internal/repository/claude_service.go @@ -0,0 +1,64 @@ +package repository + +import ( + "net/http" + "net/url" + "time" + + "sub2api/internal/config" + "sub2api/internal/service" +) + +type claudeUpstreamService struct { + defaultClient *http.Client + cfg *config.Config +} + +func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream { + responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second + if responseHeaderTimeout == 0 { + responseHeaderTimeout = 300 * time.Second + } + + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + ResponseHeaderTimeout: responseHeaderTimeout, + } + + return &claudeUpstreamService{ + defaultClient: &http.Client{Transport: transport}, + cfg: cfg, + } +} + +func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) { + if proxyURL == "" { + return s.defaultClient.Do(req) + } + client := s.createProxyClient(proxyURL) + return client.Do(req) +} + +func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client { + parsedURL, err := url.Parse(proxyURL) + if err != nil { + return s.defaultClient + } + + responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second + if responseHeaderTimeout == 0 { + responseHeaderTimeout = 300 * time.Second + } + + transport := &http.Transport{ + Proxy: http.ProxyURL(parsedURL), + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + ResponseHeaderTimeout: responseHeaderTimeout, + } + + return &http.Client{Transport: transport} +} diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go new file mode 100644 index 00000000..45d0aace --- /dev/null +++ b/backend/internal/repository/claude_usage_service.go @@ -0,0 +1,59 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "sub2api/internal/service" +) + +type claudeUsageService struct{} + +func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { + return &claudeUsageService{} +} + +func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { + transport := http.DefaultTransport.(*http.Transport).Clone() + if proxyURL != "" { + if parsedURL, err := url.Parse(proxyURL); err == nil { + transport.Proxy = http.ProxyURL(parsedURL) + } + } + + client := &http.Client{ + Transport: transport, + Timeout: 30 * time.Second, + } + + req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil) + if err != nil { + return nil, fmt.Errorf("create request failed: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("anthropic-beta", "oauth-2025-04-20") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + } + + var usageResp service.ClaudeUsageResponse + if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil { + return nil, fmt.Errorf("decode response failed: %w", err) + } + + return &usageResp, nil +} diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go new file mode 100644 index 00000000..927c0ee5 --- /dev/null +++ b/backend/internal/repository/github_release_service.go @@ -0,0 +1,116 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" + + "sub2api/internal/service" +) + +type githubReleaseClient struct { + httpClient *http.Client +} + +func NewGitHubReleaseClient() service.GitHubReleaseClient { + return &githubReleaseClient{ + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) { + url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "Sub2API-Updater") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode) + } + + var release service.GitHubRelease + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, err + } + + return &release, nil +} + +func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + + client := &http.Client{Timeout: 10 * time.Minute} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download returned %d", resp.StatusCode) + } + + // SECURITY: Check Content-Length if available + if resp.ContentLength > maxSize { + return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize) + } + + out, err := os.Create(dest) + if err != nil { + return err + } + defer out.Close() + + // SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong + limited := io.LimitReader(resp.Body, maxSize+1) + written, err := io.Copy(out, limited) + if err != nil { + return err + } + + // Check if we hit the limit (downloaded more than maxSize) + if written > maxSize { + os.Remove(dest) // Clean up partial file + return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize) + } + + return nil +} + +func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go new file mode 100644 index 00000000..e634531c --- /dev/null +++ b/backend/internal/repository/pricing_service.go @@ -0,0 +1,73 @@ +package repository + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "sub2api/internal/service" +) + +type pricingRemoteClient struct { + httpClient *http.Client +} + +func NewPricingRemoteClient() service.PricingRemoteClient { + return &pricingRemoteClient{ + httpClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", err + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + // 哈希文件格式:hash filename 或者纯 hash + hash := strings.TrimSpace(string(body)) + parts := strings.Fields(hash) + if len(parts) > 0 { + return parts[0], nil + } + return hash, nil +} diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go new file mode 100644 index 00000000..c670f3ef --- /dev/null +++ b/backend/internal/repository/proxy_probe_service.go @@ -0,0 +1,104 @@ +package repository + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "time" + + "sub2api/internal/service" + + "golang.org/x/net/proxy" +) + +type proxyProbeService struct{} + +func NewProxyExitInfoProber() service.ProxyExitInfoProber { + return &proxyProbeService{} +} + +func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { + transport, err := createProxyTransport(proxyURL) + if err != nil { + return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err) + } + + client := &http.Client{ + Transport: transport, + Timeout: 15 * time.Second, + } + + startTime := time.Now() + req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("proxy connection failed: %w", err) + } + defer resp.Body.Close() + + latencyMs := time.Since(startTime).Milliseconds() + + if resp.StatusCode != http.StatusOK { + return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) + } + + var ipInfo struct { + IP string `json:"ip"` + City string `json:"city"` + Region string `json:"region"` + Country string `json:"country"` + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) + } + + if err := json.Unmarshal(body, &ipInfo); err != nil { + return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err) + } + + return &service.ProxyExitInfo{ + IP: ipInfo.IP, + City: ipInfo.City, + Region: ipInfo.Region, + Country: ipInfo.Country, + }, latencyMs, nil +} + +func createProxyTransport(proxyURL string) (*http.Transport, error) { + parsedURL, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL: %w", err) + } + + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + switch parsedURL.Scheme { + case "http", "https": + transport.Proxy = http.ProxyURL(parsedURL) + case "socks5": + dialer, err := proxy.FromURL(parsedURL, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("failed to create socks5 dialer: %w", err) + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + default: + return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme) + } + + return transport, nil +} diff --git a/backend/internal/repository/turnstile_service.go b/backend/internal/repository/turnstile_service.go new file mode 100644 index 00000000..19152dbd --- /dev/null +++ b/backend/internal/repository/turnstile_service.go @@ -0,0 +1,55 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "sub2api/internal/service" +) + +const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify" + +type turnstileVerifier struct { + httpClient *http.Client +} + +func NewTurnstileVerifier() service.TurnstileVerifier { + return &turnstileVerifier{ + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) { + formData := url.Values{} + formData.Set("secret", secretKey) + formData.Set("response", token) + if remoteIP != "" { + formData.Set("remoteip", remoteIP) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode())) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := v.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + var result service.TurnstileVerifyResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("decode response: %w", err) + } + + return &result, nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 109afc35..ac8838cc 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -29,6 +29,15 @@ var ProviderSet = wire.NewSet( NewRedeemCache, NewUpdateCache, + // HTTP service ports (DI Strategy A: return interface directly) + NewTurnstileVerifier, + NewPricingRemoteClient, + NewGitHubReleaseClient, + NewProxyExitInfoProber, + NewClaudeUsageFetcher, + NewClaudeOAuthClient, + NewClaudeUpstream, + // Bind concrete repositories to service port interfaces wire.Bind(new(ports.UserRepository), new(*UserRepository)), wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)), diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 8cfc00f8..e7b44e83 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -10,7 +10,6 @@ import ( "io" "log" "net/http" - "net/url" "strconv" "strings" "time" @@ -37,19 +36,17 @@ type TestEvent struct { // AccountTestService handles account testing operations type AccountTestService struct { - accountRepo ports.AccountRepository - oauthService *OAuthService - httpClient *http.Client + accountRepo ports.AccountRepository + oauthService *OAuthService + claudeUpstream ClaudeUpstream } // NewAccountTestService creates a new AccountTestService -func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService { +func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService { return &AccountTestService{ - accountRepo: accountRepo, - oauthService: oauthService, - httpClient: &http.Client{ - Timeout: 60 * time.Second, - }, + accountRepo: accountRepo, + oauthService: oauthService, + claudeUpstream: claudeUpstream, } } @@ -209,23 +206,13 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int req.Header.Set("x-api-key", authToken) } - // Configure proxy if account has one - transport := http.DefaultTransport.(*http.Transport).Clone() + // Get proxy URL + proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - if parsedURL, err := url.Parse(proxyURL); err == nil { - transport.Proxy = http.ProxyURL(parsedURL) - } - } + proxyURL = account.Proxy.URL() } - client := &http.Client{ - Transport: transport, - Timeout: 60 * time.Second, - } - - resp, err := client.Do(req) + resp, err := s.claudeUpstream.Do(req, proxyURL) if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index b27830c4..0094f48f 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -2,12 +2,8 @@ package service import ( "context" - "encoding/json" "fmt" - "io" "log" - "net/http" - "net/url" "sync" "time" @@ -65,23 +61,26 @@ type ClaudeUsageResponse struct { } `json:"seven_day_sonnet"` } +// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API +type ClaudeUsageFetcher interface { + FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error) +} + // AccountUsageService 账号使用量查询服务 type AccountUsageService struct { accountRepo ports.AccountRepository usageLogRepo ports.UsageLogRepository oauthService *OAuthService - httpClient *http.Client + usageFetcher ClaudeUsageFetcher } // NewAccountUsageService 创建AccountUsageService实例 -func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService { +func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService { return &AccountUsageService{ accountRepo: accountRepo, usageLogRepo: usageLogRepo, oauthService: oauthService, - httpClient: &http.Client{ - Timeout: 30 * time.Second, - }, + usageFetcher: usageFetcher, } } @@ -179,58 +178,23 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64 // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { - // 获取access token(从credentials中获取) accessToken := account.GetCredential("access_token") if accessToken == "" { return nil, fmt.Errorf("no access token available") } - // 获取代理配置 - transport := http.DefaultTransport.(*http.Transport).Clone() + var proxyURL string if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - if parsedURL, err := url.Parse(proxyURL); err == nil { - transport.Proxy = http.ProxyURL(parsedURL) - } - } + proxyURL = account.Proxy.URL() } - client := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, - } - - // 构建请求 - req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil) + usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL) if err != nil { - return nil, fmt.Errorf("create request failed: %w", err) + return nil, err } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("anthropic-beta", "oauth-2025-04-20") - - // 发送请求 - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) - } - - // 解析响应 - var usageResp ClaudeUsageResponse - if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil { - return nil, fmt.Errorf("decode response failed: %w", err) - } - - // 转换为UsageInfo now := time.Now() - return s.buildUsageInfo(&usageResp, &now), nil + return s.buildUsageInfo(usageResp, &now), nil } // parseTime 尝试多种格式解析时间 diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 4962dafd..09e8cd20 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -2,21 +2,14 @@ package service import ( "context" - "crypto/tls" - "encoding/json" "errors" "fmt" - "io" - "net" - "net/http" - "net/url" "time" "sub2api/internal/model" "sub2api/internal/pkg/pagination" "sub2api/internal/service/ports" - "golang.org/x/net/proxy" "gorm.io/gorm" ) @@ -178,6 +171,19 @@ type ProxyTestResult struct { Country string `json:"country,omitempty"` } +// ProxyExitInfo represents proxy exit information from ipinfo.io +type ProxyExitInfo struct { + IP string + City string + Region string + Country string +} + +// ProxyExitInfoProber tests proxy connectivity and retrieves exit information +type ProxyExitInfoProber interface { + ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) +} + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo ports.UserRepository @@ -189,6 +195,7 @@ type adminServiceImpl struct { usageLogRepo ports.UsageLogRepository userSubRepo ports.UserSubscriptionRepository billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber } // NewAdminService creates a new AdminService @@ -202,6 +209,7 @@ func NewAdminService( usageLogRepo ports.UsageLogRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService, + proxyProber ProxyExitInfoProber, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -213,6 +221,7 @@ func NewAdminService( usageLogRepo: usageLogRepo, userSubRepo: userSubRepo, billingCacheService: billingCacheService, + proxyProber: proxyProber, } } @@ -876,79 +885,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR return nil, err } - return testProxyConnection(ctx, proxy) -} - -// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json -func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) { proxyURL := proxy.URL() - - // Create HTTP client with proxy - transport, err := createProxyTransport(proxyURL) + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) if err != nil { return &ProxyTestResult{ Success: false, - Message: fmt.Sprintf("Failed to create proxy transport: %v", err), - }, nil - } - - client := &http.Client{ - Transport: transport, - Timeout: 15 * time.Second, - } - - // Measure latency - startTime := time.Now() - - req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil) - if err != nil { - return &ProxyTestResult{ - Success: false, - Message: fmt.Sprintf("Failed to create request: %v", err), - }, nil - } - - resp, err := client.Do(req) - if err != nil { - return &ProxyTestResult{ - Success: false, - Message: fmt.Sprintf("Proxy connection failed: %v", err), - }, nil - } - defer resp.Body.Close() - - latencyMs := time.Since(startTime).Milliseconds() - - if resp.StatusCode != http.StatusOK { - return &ProxyTestResult{ - Success: false, - Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode), - LatencyMs: latencyMs, - }, nil - } - - // Parse ipinfo.io response - var ipInfo struct { - IP string `json:"ip"` - City string `json:"city"` - Region string `json:"region"` - Country string `json:"country"` - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return &ProxyTestResult{ - Success: true, - Message: "Proxy is accessible but failed to read response", - LatencyMs: latencyMs, - }, nil - } - - if err := json.Unmarshal(body, &ipInfo); err != nil { - return &ProxyTestResult{ - Success: true, - Message: "Proxy is accessible but failed to parse response", - LatencyMs: latencyMs, + Message: err.Error(), }, nil } @@ -956,38 +898,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes Success: true, Message: "Proxy is accessible", LatencyMs: latencyMs, - IPAddress: ipInfo.IP, - City: ipInfo.City, - Region: ipInfo.Region, - Country: ipInfo.Country, + IPAddress: exitInfo.IP, + City: exitInfo.City, + Region: exitInfo.Region, + Country: exitInfo.Country, }, nil } - -// createProxyTransport creates an HTTP transport with the given proxy URL -func createProxyTransport(proxyURL string) (*http.Transport, error) { - parsedURL, err := url.Parse(proxyURL) - if err != nil { - return nil, fmt.Errorf("invalid proxy URL: %w", err) - } - - transport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - - switch parsedURL.Scheme { - case "http", "https": - transport.Proxy = http.ProxyURL(parsedURL) - case "socks5": - dialer, err := proxy.FromURL(parsedURL, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("failed to create socks5 dialer: %w", err) - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - default: - return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme) - } - - return transport, nil -} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index aab17cca..9861db60 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,7 +12,6 @@ import ( "io" "log" "net/http" - "net/url" "regexp" "strconv" "strings" @@ -26,6 +25,11 @@ import ( "github.com/gin-gonic/gin" ) +// ClaudeUpstream handles HTTP requests to Claude API +type ClaudeUpstream interface { + Do(req *http.Request, proxyURL string) (*http.Response, error) +} + const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" @@ -87,7 +91,7 @@ type GatewayService struct { rateLimitService *RateLimitService billingCacheService *BillingCacheService identityService *IdentityService - httpClient *http.Client + claudeUpstream ClaudeUpstream } // NewGatewayService creates a new GatewayService @@ -103,20 +107,8 @@ func NewGatewayService( rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService, + claudeUpstream ClaudeUpstream, ) *GatewayService { - // 计算响应头超时时间 - responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { - responseHeaderTimeout = 300 * time.Second // 默认5分钟,LLM高负载时可能排队较久 - } - - transport := &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, // 等待上游响应头的超时 - // 注意:不设置整体 Timeout,让流式响应可以无限时间传输 - } return &GatewayService{ accountRepo: accountRepo, usageLogRepo: usageLogRepo, @@ -129,11 +121,7 @@ func NewGatewayService( rateLimitService: rateLimitService, billingCacheService: billingCacheService, identityService: identityService, - httpClient: &http.Client{ - Transport: transport, - // 不设置 Timeout:流式请求可能持续十几分钟 - // 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头) - }, + claudeUpstream: claudeUpstream, } } @@ -436,19 +424,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m } // 构建上游请求 - upstreamResult, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) if err != nil { return nil, err } - // 选择使用的client:如果有代理则使用独立的client,否则使用共享的httpClient - httpClient := s.httpClient - if upstreamResult.Client != nil { - httpClient = upstreamResult.Client + // 获取代理URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() } // 发送请求 - resp, err := httpClient.Do(upstreamResult.Request) + resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL) if err != nil { return nil, fmt.Errorf("upstream request failed: %w", err) } @@ -461,16 +449,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m if err != nil { return nil, fmt.Errorf("token refresh failed: %w", err) } - upstreamResult, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) if err != nil { return nil, err } - // 重试时也需要使用正确的client - httpClient = s.httpClient - if upstreamResult.Client != nil { - httpClient = upstreamResult.Client - } - resp, err = httpClient.Do(upstreamResult.Request) + resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL) if err != nil { return nil, fmt.Errorf("retry request failed: %w", err) } @@ -509,13 +492,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m }, nil } -// buildUpstreamRequestResult contains the request and optional custom client for proxy -type buildUpstreamRequestResult struct { - Request *http.Request - Client *http.Client // nil means use default s.httpClient -} - -func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) { +func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == model.AccountTypeApiKey { @@ -584,36 +561,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) } - // 配置代理 - 创建独立的client避免并发修改共享httpClient - var customClient *http.Client - if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - if parsedURL, err := url.Parse(proxyURL); err == nil { - // 计算响应头超时时间(与默认 Transport 保持一致) - responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { - responseHeaderTimeout = 300 * time.Second - } - transport := &http.Transport{ - Proxy: http.ProxyURL(parsedURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, - } - // 创建独立的client,避免并发时修改共享的s.httpClient.Transport - customClient = &http.Client{ - Transport: transport, - } - } - } - } - - return &buildUpstreamRequestResult{ - Request: req, - Client: customClient, - }, nil + return req, nil } // getBetaHeader 处理anthropic-beta header @@ -1085,20 +1033,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 构建上游请求 - upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) if err != nil { s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") return err } - // 选择 HTTP client - httpClient := s.httpClient - if upstreamResult.Client != nil { - httpClient = upstreamResult.Client + // 获取代理URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() } // 发送请求 - resp, err := httpClient.Do(upstreamResult.Request) + resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL) if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") return fmt.Errorf("upstream request failed: %w", err) @@ -1113,15 +1061,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed") return fmt.Errorf("token refresh failed: %w", err) } - upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) + upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) if err != nil { return err } - httpClient = s.httpClient - if upstreamResult.Client != nil { - httpClient = upstreamResult.Client - } - resp, err = httpClient.Do(upstreamResult.Request) + resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL) if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed") return fmt.Errorf("retry request failed: %w", err) @@ -1159,7 +1103,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) { +func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL if account.Type == model.AccountTypeApiKey { @@ -1223,32 +1167,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) } - // 配置代理 - var customClient *http.Client - if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - if parsedURL, err := url.Parse(proxyURL); err == nil { - responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second - if responseHeaderTimeout == 0 { - responseHeaderTimeout = 300 * time.Second - } - transport := &http.Transport{ - Proxy: http.ProxyURL(parsedURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - ResponseHeaderTimeout: responseHeaderTimeout, - } - customClient = &http.Client{Transport: transport} - } - } - } - - return &buildUpstreamRequestResult{ - Request: req, - Client: customClient, - }, nil + return req, nil } // countTokensError 返回 count_tokens 错误响应 diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 829e267d..251bf446 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -2,32 +2,36 @@ package service import ( "context" - "encoding/json" "fmt" "log" - "net/http" - "net/url" - "strings" "time" "sub2api/internal/model" "sub2api/internal/pkg/oauth" "sub2api/internal/service/ports" - - "github.com/imroc/req/v3" ) +// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows +type ClaudeOAuthClient interface { + GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) + GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) + RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) +} + // OAuthService handles OAuth authentication flows type OAuthService struct { sessionStore *oauth.SessionStore proxyRepo ports.ProxyRepository + oauthClient ClaudeOAuthClient } // NewOAuthService creates a new OAuth service -func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService { +func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService { return &OAuthService{ sessionStore: oauth.NewSessionStore(), proxyRepo: proxyRepo, + oauthClient: oauthClient, } } @@ -210,177 +214,21 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( // getOrganizationUUID gets the organization UUID from claude.ai using sessionKey func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client := s.createReqClient(proxyURL) - - var orgs []struct { - UUID string `json:"uuid"` - } - - targetURL := "https://claude.ai/api/organizations" - log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) - - resp, err := client.R(). - SetContext(ctx). - SetCookies(&http.Cookie{ - Name: "sessionKey", - Value: sessionKey, - }). - SetSuccessResult(&orgs). - Get(targetURL) - - if err != nil { - log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) - return "", fmt.Errorf("request failed: %w", err) - } - - log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) - - if !resp.IsSuccessState() { - return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) - } - - if len(orgs) == 0 { - return "", fmt.Errorf("no organizations found") - } - - log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID) - return orgs[0].UUID, nil + return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL) } // getAuthorizationCode gets the authorization code using sessionKey func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client := s.createReqClient(proxyURL) - - authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID) - - // Build request body - must include organization_uuid as per CRS - reqBody := map[string]interface{}{ - "response_type": "code", - "client_id": oauth.ClientID, - "organization_uuid": orgUUID, // Required field! - "redirect_uri": oauth.RedirectURI, - "scope": scope, - "state": state, - "code_challenge": codeChallenge, - "code_challenge_method": "S256", - } - - reqBodyJSON, _ := json.Marshal(reqBody) - log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) - log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) - - // Response contains redirect_uri with code, not direct code field - var result struct { - RedirectURI string `json:"redirect_uri"` - } - - resp, err := client.R(). - SetContext(ctx). - SetCookies(&http.Cookie{ - Name: "sessionKey", - Value: sessionKey, - }). - SetHeader("Accept", "application/json"). - SetHeader("Accept-Language", "en-US,en;q=0.9"). - SetHeader("Cache-Control", "no-cache"). - SetHeader("Origin", "https://claude.ai"). - SetHeader("Referer", "https://claude.ai/new"). - SetHeader("Content-Type", "application/json"). - SetBody(reqBody). - SetSuccessResult(&result). - Post(authURL) - - if err != nil { - log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) - return "", fmt.Errorf("request failed: %w", err) - } - - log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) - - if !resp.IsSuccessState() { - return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) - } - - if result.RedirectURI == "" { - return "", fmt.Errorf("no redirect_uri in response") - } - - // Parse redirect_uri to extract code and state - parsedURL, err := url.Parse(result.RedirectURI) - if err != nil { - return "", fmt.Errorf("failed to parse redirect_uri: %w", err) - } - - queryParams := parsedURL.Query() - authCode := queryParams.Get("code") - responseState := queryParams.Get("state") - - if authCode == "" { - return "", fmt.Errorf("no authorization code in redirect_uri") - } - - // Combine code with state if present (as CRS does) - fullCode := authCode - if responseState != "" { - fullCode = authCode + "#" + responseState - } - - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20]) - return fullCode, nil + return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL) } // exchangeCodeForToken exchanges authorization code for tokens func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) { - client := s.createReqClient(proxyURL) - - // Parse code#state format if present - authCode := code - codeState := "" - if parts := strings.Split(code, "#"); len(parts) > 1 { - authCode = parts[0] - codeState = parts[1] - } - - // Build JSON body as CRS does (not form data!) - reqBody := map[string]interface{}{ - "code": authCode, - "grant_type": "authorization_code", - "client_id": oauth.ClientID, - "redirect_uri": oauth.RedirectURI, - "code_verifier": codeVerifier, - } - - // Add state if present - if codeState != "" { - reqBody["state"] = codeState - } - - reqBodyJSON, _ := json.Marshal(reqBody) - log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL) - log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) - - var tokenResp oauth.TokenResponse - - resp, err := client.R(). - SetContext(ctx). - SetHeader("Content-Type", "application/json"). - SetBody(reqBody). - SetSuccessResult(&tokenResp). - Post(oauth.TokenURL) - + tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL) if err != nil { - log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) - return nil, fmt.Errorf("request failed: %w", err) + return nil, err } - log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String()) - - if !resp.IsSuccessState() { - return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) - } - - log.Printf("[OAuth] Step 3 SUCCESS - Got access token") - tokenInfo := &TokenInfo{ AccessToken: tokenResp.AccessToken, TokenType: tokenResp.TokenType, @@ -390,7 +238,6 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif Scope: tokenResp.Scope, } - // Extract org_uuid and account_uuid from response if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" { tokenInfo.OrgUUID = tokenResp.Organization.UUID log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID) @@ -405,27 +252,9 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif // RefreshToken refreshes an OAuth token func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) { - client := s.createReqClient(proxyURL) - - formData := url.Values{} - formData.Set("grant_type", "refresh_token") - formData.Set("refresh_token", refreshToken) - formData.Set("client_id", oauth.ClientID) - - var tokenResp oauth.TokenResponse - - resp, err := client.R(). - SetContext(ctx). - SetFormDataFromValues(formData). - SetSuccessResult(&tokenResp). - Post(oauth.TokenURL) - + tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) if err != nil { - return nil, fmt.Errorf("request failed: %w", err) - } - - if !resp.IsSuccessState() { - return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String()) + return nil, err } return &TokenInfo{ @@ -455,17 +284,3 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A return s.RefreshToken(ctx, refreshToken, proxyURL) } - -// createReqClient creates a req client with Chrome impersonation and optional proxy -func (s *OAuthService) createReqClient(proxyURL string) *req.Client { - client := req.C(). - ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare - SetTimeout(60 * time.Second) - - // Set proxy if specified - if proxyURL != "" { - client.SetProxyURL(proxyURL) - } - - return client -} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index d71908a5..69b6c2ef 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -1,13 +1,12 @@ package service import ( + "context" "crypto/sha256" "encoding/hex" "encoding/json" "fmt" - "io" "log" - "net/http" "os" "path/filepath" "strings" @@ -20,13 +19,19 @@ import ( // LiteLLMModelPricing LiteLLM价格数据结构 // 只保留我们需要的字段,使用指针来处理可能缺失的值 type LiteLLMModelPricing struct { - InputCostPerToken float64 `json:"input_cost_per_token"` - OutputCostPerToken float64 `json:"output_cost_per_token"` - CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` - CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` - LiteLLMProvider string `json:"litellm_provider"` - Mode string `json:"mode"` - SupportsPromptCaching bool `json:"supports_prompt_caching"` + InputCostPerToken float64 `json:"input_cost_per_token"` + OutputCostPerToken float64 `json:"output_cost_per_token"` + CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` + CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` + LiteLLMProvider string `json:"litellm_provider"` + Mode string `json:"mode"` + SupportsPromptCaching bool `json:"supports_prompt_caching"` +} + +// PricingRemoteClient 远程价格数据获取接口 +type PricingRemoteClient interface { + FetchPricingJSON(ctx context.Context, url string) ([]byte, error) + FetchHashText(ctx context.Context, url string) (string, error) } // LiteLLMRawEntry 用于解析原始JSON数据 @@ -42,11 +47,12 @@ type LiteLLMRawEntry struct { // PricingService 动态价格服务 type PricingService struct { - cfg *config.Config - mu sync.RWMutex - pricingData map[string]*LiteLLMModelPricing - lastUpdated time.Time - localHash string + cfg *config.Config + remoteClient PricingRemoteClient + mu sync.RWMutex + pricingData map[string]*LiteLLMModelPricing + lastUpdated time.Time + localHash string // 停止信号 stopCh chan struct{} @@ -54,11 +60,12 @@ type PricingService struct { } // NewPricingService 创建价格服务 -func NewPricingService(cfg *config.Config) *PricingService { +func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService { s := &PricingService{ - cfg: cfg, - pricingData: make(map[string]*LiteLLMModelPricing), - stopCh: make(chan struct{}), + cfg: cfg, + remoteClient: remoteClient, + pricingData: make(map[string]*LiteLLMModelPricing), + stopCh: make(chan struct{}), } return s } @@ -199,21 +206,13 @@ func (s *PricingService) syncWithRemote() error { func (s *PricingService) downloadPricingData() error { log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Get(s.cfg.Pricing.RemoteURL) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL) if err != nil { return fmt.Errorf("download failed: %w", err) } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed: HTTP %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("read response failed: %w", err) - } // 解析JSON数据(使用灵活的解析方式) data, err := s.parsePricingData(body) @@ -367,29 +366,10 @@ func (s *PricingService) useFallbackPricing() error { // fetchRemoteHash 从远程获取哈希值 func (s *PricingService) fetchRemoteHash() (string, error) { - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Get(s.cfg.Pricing.HashURL) - if err != nil { - return "", err - } - defer resp.Body.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("HTTP %d", resp.StatusCode) - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - - // 哈希文件格式:hash filename 或者纯 hash - hash := strings.TrimSpace(string(body)) - parts := strings.Fields(hash) - if len(parts) > 0 { - return parts[0], nil - } - return hash, nil + return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL) } // computeFileHash 计算文件哈希 @@ -466,14 +446,14 @@ func (s *PricingService) extractBaseName(model string) string { func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // Claude模型系列匹配规则 familyPatterns := map[string][]string{ - "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"}, - "opus-4": {"claude-opus-4", "claude-3-opus"}, - "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"}, - "sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"}, - "sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"}, - "sonnet-3": {"claude-3-sonnet"}, - "haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"}, - "haiku-3": {"claude-3-haiku"}, + "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"}, + "opus-4": {"claude-opus-4", "claude-3-opus"}, + "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"}, + "sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"}, + "sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"}, + "sonnet-3": {"claude-3-sonnet"}, + "haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"}, + "haiku-3": {"claude-3-haiku"}, } // 确定模型属于哪个系列 diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go index 10b98354..b5f37e1d 100644 --- a/backend/internal/service/service.go +++ b/backend/internal/service/service.go @@ -26,4 +26,5 @@ type Services struct { Subscription *SubscriptionService Concurrency *ConcurrencyService Identity *IdentityService + Update *UpdateService } diff --git a/backend/internal/service/turnstile_service.go b/backend/internal/service/turnstile_service.go index 7603c782..81b6e3a0 100644 --- a/backend/internal/service/turnstile_service.go +++ b/backend/internal/service/turnstile_service.go @@ -2,14 +2,9 @@ package service import ( "context" - "encoding/json" "errors" "fmt" "log" - "net/http" - "net/url" - "strings" - "time" ) var ( @@ -19,10 +14,15 @@ var ( const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify" +// TurnstileVerifier 验证 Turnstile token 的接口 +type TurnstileVerifier interface { + VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error) +} + // TurnstileService Turnstile 验证服务 type TurnstileService struct { settingService *SettingService - httpClient *http.Client + verifier TurnstileVerifier } // TurnstileVerifyResponse Cloudflare Turnstile 验证响应 @@ -36,12 +36,10 @@ type TurnstileVerifyResponse struct { } // NewTurnstileService 创建 Turnstile 服务实例 -func NewTurnstileService(settingService *SettingService) *TurnstileService { +func NewTurnstileService(settingService *SettingService, verifier TurnstileVerifier) *TurnstileService { return &TurnstileService{ settingService: settingService, - httpClient: &http.Client{ - Timeout: 10 * time.Second, - }, + verifier: verifier, } } @@ -66,35 +64,12 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote return ErrTurnstileVerificationFailed } - // 构建请求 - formData := url.Values{} - formData.Set("secret", secretKey) - formData.Set("response", token) - if remoteIP != "" { - formData.Set("remoteip", remoteIP) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode())) - if err != nil { - return fmt.Errorf("create request: %w", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - // 发送请求 log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP) - resp, err := s.httpClient.Do(req) + result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP) if err != nil { log.Printf("[Turnstile] Request failed: %v", err) return fmt.Errorf("send request: %w", err) } - defer resp.Body.Close() - - // 解析响应 - var result TurnstileVerifyResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - log.Printf("[Turnstile] Failed to decode response: %v", err) - return fmt.Errorf("decode response: %w", err) - } if !result.Success { log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes) diff --git a/backend/internal/service/update_service.go b/backend/internal/service/update_service.go index c65799cc..bc88575b 100644 --- a/backend/internal/service/update_service.go +++ b/backend/internal/service/update_service.go @@ -10,7 +10,6 @@ import ( "encoding/json" "fmt" "io" - "net/http" "net/url" "os" "path/filepath" @@ -34,17 +33,26 @@ const ( maxDownloadSize = 500 * 1024 * 1024 ) +// GitHubReleaseClient 获取 GitHub release 信息的接口 +type GitHubReleaseClient interface { + FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error) + DownloadFile(ctx context.Context, url, dest string, maxSize int64) error + FetchChecksumFile(ctx context.Context, url string) ([]byte, error) +} + // UpdateService handles software updates type UpdateService struct { cache ports.UpdateCache + githubClient GitHubReleaseClient currentVersion string buildType string // "source" for manual builds, "release" for CI builds } // NewUpdateService creates a new UpdateService -func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService { +func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService { return &UpdateService{ cache: cache, + githubClient: githubClient, currentVersion: version, buildType: buildType, } @@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error { return nil } - func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) { - url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo) - - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo) if err != nil { return nil, err } - req.Header.Set("Accept", "application/vnd.github.v3+json") - req.Header.Set("User-Agent", "Sub2API-Updater") - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - return &UpdateInfo{ - CurrentVersion: s.currentVersion, - LatestVersion: s.currentVersion, - HasUpdate: false, - Warning: "No releases found", - BuildType: s.buildType, - }, nil - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode) - } - - var release GitHubRelease - if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { - return nil, err - } latestVersion := strings.TrimPrefix(release.TagName, "v") @@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er } func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error { - req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) - if err != nil { - return err - } - - client := &http.Client{Timeout: 10 * time.Minute} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download returned %d", resp.StatusCode) - } - - // SECURITY: Check Content-Length if available - if resp.ContentLength > maxDownloadSize { - return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize) - } - - out, err := os.Create(dest) - if err != nil { - return err - } - defer out.Close() - - // SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong - limited := io.LimitReader(resp.Body, maxDownloadSize+1) - written, err := io.Copy(out, limited) - if err != nil { - return err - } - - // Check if we hit the limit (downloaded more than maxDownloadSize) - if written > maxDownloadSize { - os.Remove(dest) // Clean up partial file - return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize) - } - - return nil + return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize) } func (s *UpdateService) getArchiveName() string { @@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error { func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error { // Download checksums file - req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil) + checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL) if err != nil { - return err - } - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to download checksums: %d", resp.StatusCode) + return fmt.Errorf("failed to download checksums: %w", err) } // Calculate file hash @@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR // Find expected hash in checksums file fileName := filepath.Base(filePath) - scanner := bufio.NewScanner(resp.Body) + scanner := bufio.NewScanner(strings.NewReader(string(checksumData))) for scanner.Scan() { line := scanner.Text() parts := strings.Fields(line) diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 9fbe670f..eca145a5 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -2,13 +2,20 @@ package service import ( "sub2api/internal/config" + "sub2api/internal/service/ports" "github.com/google/wire" ) +// BuildInfo contains build information +type BuildInfo struct { + Version string + BuildType string +} + // ProvidePricingService creates and initializes PricingService -func ProvidePricingService(cfg *config.Config) (*PricingService, error) { - svc := NewPricingService(cfg) +func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) { + svc := NewPricingService(cfg, remoteClient) if err := svc.Initialize(); err != nil { // 价格服务初始化失败不应阻止启动,使用回退价格 println("[Service] Warning: Pricing service initialization failed:", err.Error()) @@ -16,6 +23,11 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) { return svc, nil } +// ProvideUpdateService creates UpdateService with BuildInfo +func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService { + return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType) +} + // ProvideEmailQueueService creates EmailQueueService with default worker count func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { return NewEmailQueueService(emailService, 3) @@ -48,6 +60,7 @@ var ProviderSet = wire.NewSet( NewSubscriptionService, NewConcurrencyService, NewIdentityService, + ProvideUpdateService, // Provide the Services container struct wire.Struct(new(Services), "*"),