Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports

refactor(backend): service http ports
This commit is contained in:
Wesley Liddick
2025-12-19 23:35:13 -05:00
committed by GitHub
22 changed files with 927 additions and 695 deletions

View File

@@ -40,6 +40,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// 服务器层 ProviderSet // 服务器层 ProviderSet
server.ProviderSet, server.ProviderSet,
// BuildInfo provider
provideServiceBuildInfo,
// 清理函数提供者 // 清理函数提供者
provideCleanup, provideCleanup,
@@ -49,6 +52,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
return nil, nil return nil, nil
} }
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup( func provideCleanup(
db *gorm.DB, db *gorm.DB,
rdb *redis.Client, rdb *redis.Client,

View File

@@ -43,7 +43,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
client := infrastructure.ProvideRedis(configConfig) client := infrastructure.ProvideRedis(configConfig)
emailCache := repository.NewEmailCache(client) emailCache := repository.NewEmailCache(client)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileService := service.NewTurnstileService(settingService) turnstileVerifier := repository.NewTurnstileVerifier()
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
emailQueueService := service.ProvideEmailQueueService(emailService) emailQueueService := service.ProvideEmailQueueService(emailService)
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
authHandler := handler.NewAuthHandler(authService) authHandler := handler.NewAuthHandler(authService)
@@ -68,32 +69,41 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
accountRepository := repository.NewAccountRepository(db) accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db) proxyRepository := repository.NewProxyRepository(db)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService) proxyExitInfoProber := repository.NewProxyExitInfoProber()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository) dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminUserHandler := admin.NewUserHandler(adminService) adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService) groupHandler := admin.NewGroupHandler(adminService)
oAuthService := service.NewOAuthService(proxyRepository) claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig) rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService) claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountTestService := service.NewAccountTestService(accountRepository, oAuthService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
claudeUpstream := repository.NewClaudeUpstream(configConfig)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService) oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService) proxyHandler := admin.NewProxyHandler(adminService)
adminRedeemHandler := admin.NewRedeemHandler(adminService) adminRedeemHandler := admin.NewRedeemHandler(adminService)
settingHandler := admin.NewSettingHandler(settingService, emailService) settingHandler := admin.NewSettingHandler(settingService, emailService)
systemHandler := handler.ProvideSystemHandler(client, buildInfo) updateCache := repository.NewUpdateCache(client)
gitHubReleaseClient := repository.NewGitHubReleaseClient()
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService) adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
gatewayCache := repository.NewGatewayCache(client) gatewayCache := repository.NewGatewayCache(client)
pricingService, err := service.ProvidePricingService(configConfig) pricingRemoteClient := repository.NewPricingRemoteClient()
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
billingService := service.NewBillingService(configConfig, pricingService) billingService := service.NewBillingService(configConfig, pricingService)
identityCache := repository.NewIdentityCache(client) identityCache := repository.NewIdentityCache(client)
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
concurrencyCache := repository.NewConcurrencyCache(client) concurrencyCache := repository.NewConcurrencyCache(client)
concurrencyService := service.NewConcurrencyService(concurrencyCache) concurrencyService := service.NewConcurrencyService(concurrencyCache)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
@@ -127,6 +137,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
Subscription: subscriptionService, Subscription: subscriptionService,
Concurrency: concurrencyService, Concurrency: concurrencyService,
Identity: identityService, Identity: identityService,
Update: updateService,
} }
repositories := &repository.Repositories{ repositories := &repository.Repositories{
User: userRepository, User: userRepository,
@@ -156,6 +167,13 @@ type Application struct {
Cleanup func() Cleanup func()
} }
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
return service.BuildInfo{
Version: buildInfo.Version,
BuildType: buildInfo.BuildType,
}
}
func provideCleanup( func provideCleanup(
db *gorm.DB, db *gorm.DB,
rdb *redis.Client, rdb *redis.Client,

View File

@@ -6,11 +6,9 @@ import (
"sub2api/internal/pkg/response" "sub2api/internal/pkg/response"
"sub2api/internal/pkg/sysutil" "sub2api/internal/pkg/sysutil"
"sub2api/internal/repository"
"sub2api/internal/service" "sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
) )
// SystemHandler handles system-related operations // SystemHandler handles system-related operations
@@ -19,10 +17,9 @@ type SystemHandler struct {
} }
// NewSystemHandler creates a new SystemHandler // NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler { func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
updateCache := repository.NewUpdateCache(rdb)
return &SystemHandler{ return &SystemHandler{
updateSvc: service.NewUpdateService(updateCache, version, buildType), updateSvc: updateSvc,
} }
} }

View File

@@ -5,7 +5,6 @@ import (
"sub2api/internal/service" "sub2api/internal/service"
"github.com/google/wire" "github.com/google/wire"
"github.com/redis/go-redis/v9"
) )
// ProvideAdminHandlers creates the AdminHandlers struct // ProvideAdminHandlers creates the AdminHandlers struct
@@ -37,9 +36,9 @@ func ProvideAdminHandlers(
} }
} }
// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters // ProvideSystemHandler creates admin.SystemHandler with UpdateService
func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler { func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType) return admin.NewSystemHandler(updateService)
} }
// ProvideSettingHandler creates SettingHandler with version from BuildInfo // ProvideSettingHandler creates SettingHandler with version from BuildInfo

View File

@@ -0,0 +1,235 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"time"
"sub2api/internal/pkg/oauth"
"sub2api/internal/service"
"github.com/imroc/req/v3"
)
type claudeOAuthService struct{}
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{}
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
reqBody := map[string]interface{}{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID,
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
authCode := code
codeState := ""
if len(code) > 0 {
parts := make([]string, 0, 2)
for i, part := range []rune(code) {
if part == '#' {
authCode = code[:i]
codeState = code[i+1:]
break
}
}
if len(parts) == 0 {
authCode = code
}
}
reqBody := map[string]interface{}{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome().
SetTimeout(60 * time.Second)
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"net/http"
"net/url"
"time"
"sub2api/internal/config"
"sub2api/internal/service"
)
type claudeUpstreamService struct {
defaultClient *http.Client
cfg *config.Config
}
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &claudeUpstreamService{
defaultClient: &http.Client{Transport: transport},
cfg: cfg,
}
}
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
if proxyURL == "" {
return s.defaultClient.Do(req)
}
client := s.createProxyClient(proxyURL)
return client.Do(req)
}
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return s.defaultClient
}
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &http.Client{Transport: transport}
}

View File

@@ -0,0 +1,59 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"sub2api/internal/service"
)
type claudeUsageService struct{}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
client := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
var usageResp service.ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
return &usageResp, nil
}

View File

@@ -0,0 +1,116 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"sub2api/internal/service"
)
type githubReleaseClient struct {
httpClient *http.Client
}
func NewGitHubReleaseClient() service.GitHubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release service.GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxSize)
if written > maxSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
}
return nil
}
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

View File

@@ -0,0 +1,73 @@
package repository
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"sub2api/internal/service"
)
type pricingRemoteClient struct {
httpClient *http.Client
}
func NewPricingRemoteClient() service.PricingRemoteClient {
return &pricingRemoteClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}

View File

@@ -0,0 +1,104 @@
package repository
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"sub2api/internal/service"
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
transport, err := createProxyTransport(proxyURL)
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
return &service.ProxyExitInfo{
IP: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
}, latencyMs, nil
}
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -0,0 +1,55 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"sub2api/internal/service"
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
type turnstileVerifier struct {
httpClient *http.Client
}
func NewTurnstileVerifier() service.TurnstileVerifier {
return &turnstileVerifier{
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := v.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
var result service.TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
return &result, nil
}

View File

@@ -29,6 +29,15 @@ var ProviderSet = wire.NewSet(
NewRedeemCache, NewRedeemCache,
NewUpdateCache, NewUpdateCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
NewPricingRemoteClient,
NewGitHubReleaseClient,
NewProxyExitInfoProber,
NewClaudeUsageFetcher,
NewClaudeOAuthClient,
NewClaudeUpstream,
// Bind concrete repositories to service port interfaces // Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)), wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)), wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),

View File

@@ -10,7 +10,6 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -37,19 +36,17 @@ type TestEvent struct {
// AccountTestService handles account testing operations // AccountTestService handles account testing operations
type AccountTestService struct { type AccountTestService struct {
accountRepo ports.AccountRepository accountRepo ports.AccountRepository
oauthService *OAuthService oauthService *OAuthService
httpClient *http.Client claudeUpstream ClaudeUpstream
} }
// NewAccountTestService creates a new AccountTestService // NewAccountTestService creates a new AccountTestService
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService { func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
return &AccountTestService{ return &AccountTestService{
accountRepo: accountRepo, accountRepo: accountRepo,
oauthService: oauthService, oauthService: oauthService,
httpClient: &http.Client{ claudeUpstream: claudeUpstream,
Timeout: 60 * time.Second,
},
} }
} }
@@ -209,23 +206,13 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
req.Header.Set("x-api-key", authToken) req.Header.Set("x-api-key", authToken)
} }
// Configure proxy if account has one // Get proxy URL
transport := http.DefaultTransport.(*http.Transport).Clone() proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL() proxyURL = account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
} }
client := &http.Client{ resp, err := s.claudeUpstream.Do(req, proxyURL)
Transport: transport,
Timeout: 60 * time.Second,
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
} }

View File

@@ -2,12 +2,8 @@ package service
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"log" "log"
"net/http"
"net/url"
"sync" "sync"
"time" "time"
@@ -65,23 +61,26 @@ type ClaudeUsageResponse struct {
} `json:"seven_day_sonnet"` } `json:"seven_day_sonnet"`
} }
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
type ClaudeUsageFetcher interface {
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
}
// AccountUsageService 账号使用量查询服务 // AccountUsageService 账号使用量查询服务
type AccountUsageService struct { type AccountUsageService struct {
accountRepo ports.AccountRepository accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository usageLogRepo ports.UsageLogRepository
oauthService *OAuthService oauthService *OAuthService
httpClient *http.Client usageFetcher ClaudeUsageFetcher
} }
// NewAccountUsageService 创建AccountUsageService实例 // NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService { func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
return &AccountUsageService{ return &AccountUsageService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
oauthService: oauthService, oauthService: oauthService,
httpClient: &http.Client{ usageFetcher: usageFetcher,
Timeout: 30 * time.Second,
},
} }
} }
@@ -179,58 +178,23 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量 // fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) { func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
// 获取access token从credentials中获取
accessToken := account.GetCredential("access_token") accessToken := account.GetCredential("access_token")
if accessToken == "" { if accessToken == "" {
return nil, fmt.Errorf("no access token available") return nil, fmt.Errorf("no access token available")
} }
// 获取代理配置 var proxyURL string
transport := http.DefaultTransport.(*http.Transport).Clone()
if account.ProxyID != nil && account.Proxy != nil { if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL() proxyURL = account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
transport.Proxy = http.ProxyURL(parsedURL)
}
}
} }
client := &http.Client{ usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
Transport: transport,
Timeout: 30 * time.Second,
}
// 构建请求
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("create request failed: %w", err) return nil, err
} }
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
// 解析响应
var usageResp ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
// 转换为UsageInfo
now := time.Now() now := time.Now()
return s.buildUsageInfo(&usageResp, &now), nil return s.buildUsageInfo(usageResp, &now), nil
} }
// parseTime 尝试多种格式解析时间 // parseTime 尝试多种格式解析时间

View File

@@ -2,21 +2,14 @@ package service
import ( import (
"context" "context"
"crypto/tls"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net"
"net/http"
"net/url"
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/pagination" "sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports" "sub2api/internal/service/ports"
"golang.org/x/net/proxy"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -178,6 +171,19 @@ type ProxyTestResult struct {
Country string `json:"country,omitempty"` Country string `json:"country,omitempty"`
} }
// ProxyExitInfo represents proxy exit information from ipinfo.io
type ProxyExitInfo struct {
IP string
City string
Region string
Country string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
type ProxyExitInfoProber interface {
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
}
// adminServiceImpl implements AdminService // adminServiceImpl implements AdminService
type adminServiceImpl struct { type adminServiceImpl struct {
userRepo ports.UserRepository userRepo ports.UserRepository
@@ -189,6 +195,7 @@ type adminServiceImpl struct {
usageLogRepo ports.UsageLogRepository usageLogRepo ports.UsageLogRepository
userSubRepo ports.UserSubscriptionRepository userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
} }
// NewAdminService creates a new AdminService // NewAdminService creates a new AdminService
@@ -202,6 +209,7 @@ func NewAdminService(
usageLogRepo ports.UsageLogRepository, usageLogRepo ports.UsageLogRepository,
userSubRepo ports.UserSubscriptionRepository, userSubRepo ports.UserSubscriptionRepository,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
) AdminService { ) AdminService {
return &adminServiceImpl{ return &adminServiceImpl{
userRepo: userRepo, userRepo: userRepo,
@@ -213,6 +221,7 @@ func NewAdminService(
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
userSubRepo: userSubRepo, userSubRepo: userSubRepo,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
proxyProber: proxyProber,
} }
} }
@@ -876,79 +885,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
return nil, err return nil, err
} }
return testProxyConnection(ctx, proxy)
}
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
proxyURL := proxy.URL() proxyURL := proxy.URL()
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
// Create HTTP client with proxy
transport, err := createProxyTransport(proxyURL)
if err != nil { if err != nil {
return &ProxyTestResult{ return &ProxyTestResult{
Success: false, Success: false,
Message: fmt.Sprintf("Failed to create proxy transport: %v", err), Message: err.Error(),
}, nil
}
client := &http.Client{
Transport: transport,
Timeout: 15 * time.Second,
}
// Measure latency
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Failed to create request: %v", err),
}, nil
}
resp, err := client.Do(req)
if err != nil {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Proxy connection failed: %v", err),
}, nil
}
defer resp.Body.Close()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return &ProxyTestResult{
Success: false,
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
LatencyMs: latencyMs,
}, nil
}
// Parse ipinfo.io response
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to read response",
LatencyMs: latencyMs,
}, nil
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return &ProxyTestResult{
Success: true,
Message: "Proxy is accessible but failed to parse response",
LatencyMs: latencyMs,
}, nil }, nil
} }
@@ -956,38 +898,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
Success: true, Success: true,
Message: "Proxy is accessible", Message: "Proxy is accessible",
LatencyMs: latencyMs, LatencyMs: latencyMs,
IPAddress: ipInfo.IP, IPAddress: exitInfo.IP,
City: ipInfo.City, City: exitInfo.City,
Region: ipInfo.Region, Region: exitInfo.Region,
Country: ipInfo.Country, Country: exitInfo.Country,
}, nil }, nil
} }
// createProxyTransport creates an HTTP transport with the given proxy URL
func createProxyTransport(proxyURL string) (*http.Transport, error) {
parsedURL, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL: %w", err)
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
switch parsedURL.Scheme {
case "http", "https":
transport.Proxy = http.ProxyURL(parsedURL)
case "socks5":
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
}
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
}
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
}
return transport, nil
}

View File

@@ -12,7 +12,6 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
@@ -26,6 +25,11 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// ClaudeUpstream handles HTTP requests to Claude API
type ClaudeUpstream interface {
Do(req *http.Request, proxyURL string) (*http.Response, error)
}
const ( const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
@@ -87,7 +91,7 @@ type GatewayService struct {
rateLimitService *RateLimitService rateLimitService *RateLimitService
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
identityService *IdentityService identityService *IdentityService
httpClient *http.Client claudeUpstream ClaudeUpstream
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
@@ -103,20 +107,8 @@ func NewGatewayService(
rateLimitService *RateLimitService, rateLimitService *RateLimitService,
billingCacheService *BillingCacheService, billingCacheService *BillingCacheService,
identityService *IdentityService, identityService *IdentityService,
claudeUpstream ClaudeUpstream,
) *GatewayService { ) *GatewayService {
// 计算响应头超时时间
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second // 默认5分钟LLM高负载时可能排队较久
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout, // 等待上游响应头的超时
// 注意:不设置整体 Timeout让流式响应可以无限时间传输
}
return &GatewayService{ return &GatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
usageLogRepo: usageLogRepo, usageLogRepo: usageLogRepo,
@@ -129,11 +121,7 @@ func NewGatewayService(
rateLimitService: rateLimitService, rateLimitService: rateLimitService,
billingCacheService: billingCacheService, billingCacheService: billingCacheService,
identityService: identityService, identityService: identityService,
httpClient: &http.Client{ claudeUpstream: claudeUpstream,
Transport: transport,
// 不设置 Timeout流式请求可能持续十几分钟
// 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头)
},
} }
} }
@@ -436,19 +424,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
} }
// 构建上游请求 // 构建上游请求
upstreamResult, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 选择使用的client如果有代理则使用独立的client否则使用共享的httpClient // 获取代理URL
httpClient := s.httpClient proxyURL := ""
if upstreamResult.Client != nil { if account.ProxyID != nil && account.Proxy != nil {
httpClient = upstreamResult.Client proxyURL = account.Proxy.URL()
} }
// 发送请求 // 发送请求
resp, err := httpClient.Do(upstreamResult.Request) resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err) return nil, fmt.Errorf("upstream request failed: %w", err)
} }
@@ -461,16 +449,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
if err != nil { if err != nil {
return nil, fmt.Errorf("token refresh failed: %w", err) return nil, fmt.Errorf("token refresh failed: %w", err)
} }
upstreamResult, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType) upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 重试时也需要使用正确的client resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
httpClient = s.httpClient
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
resp, err = httpClient.Do(upstreamResult.Request)
if err != nil { if err != nil {
return nil, fmt.Errorf("retry request failed: %w", err) return nil, fmt.Errorf("retry request failed: %w", err)
} }
@@ -509,13 +492,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}, nil }, nil
} }
// buildUpstreamRequestResult contains the request and optional custom client for proxy func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
type buildUpstreamRequestResult struct {
Request *http.Request
Client *http.Client // nil means use default s.httpClient
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
// 确定目标URL // 确定目标URL
targetURL := claudeAPIURL targetURL := claudeAPIURL
if account.Type == model.AccountTypeApiKey { if account.Type == model.AccountTypeApiKey {
@@ -584,36 +561,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
} }
// 配置代理 - 创建独立的client避免并发修改共享httpClient return req, nil
var customClient *http.Client
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
// 计算响应头超时时间(与默认 Transport 保持一致)
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
// 创建独立的client避免并发时修改共享的s.httpClient.Transport
customClient = &http.Client{
Transport: transport,
}
}
}
}
return &buildUpstreamRequestResult{
Request: req,
Client: customClient,
}, nil
} }
// getBetaHeader 处理anthropic-beta header // getBetaHeader 处理anthropic-beta header
@@ -1085,20 +1033,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// 构建上游请求 // 构建上游请求
upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil { if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err return err
} }
// 选择 HTTP client // 获取代理URL
httpClient := s.httpClient proxyURL := ""
if upstreamResult.Client != nil { if account.ProxyID != nil && account.Proxy != nil {
httpClient = upstreamResult.Client proxyURL = account.Proxy.URL()
} }
// 发送请求 // 发送请求
resp, err := httpClient.Do(upstreamResult.Request) resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
if err != nil { if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err) return fmt.Errorf("upstream request failed: %w", err)
@@ -1113,15 +1061,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
return fmt.Errorf("token refresh failed: %w", err) return fmt.Errorf("token refresh failed: %w", err)
} }
upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType) upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
if err != nil { if err != nil {
return err return err
} }
httpClient = s.httpClient resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
if upstreamResult.Client != nil {
httpClient = upstreamResult.Client
}
resp, err = httpClient.Do(upstreamResult.Request)
if err != nil { if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
return fmt.Errorf("retry request failed: %w", err) return fmt.Errorf("retry request failed: %w", err)
@@ -1159,7 +1103,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
} }
// buildCountTokensRequest 构建 count_tokens 上游请求 // buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) { func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
// 确定目标 URL // 确定目标 URL
targetURL := claudeAPICountTokensURL targetURL := claudeAPICountTokensURL
if account.Type == model.AccountTypeApiKey { if account.Type == model.AccountTypeApiKey {
@@ -1223,32 +1167,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta"))) req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
} }
// 配置代理 return req, nil
var customClient *http.Client
if account.ProxyID != nil && account.Proxy != nil {
proxyURL := account.Proxy.URL()
if proxyURL != "" {
if parsedURL, err := url.Parse(proxyURL); err == nil {
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
responseHeaderTimeout = 300 * time.Second
}
transport := &http.Transport{
Proxy: http.ProxyURL(parsedURL),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
ResponseHeaderTimeout: responseHeaderTimeout,
}
customClient = &http.Client{Transport: transport}
}
}
}
return &buildUpstreamRequestResult{
Request: req,
Client: customClient,
}, nil
} }
// countTokensError 返回 count_tokens 错误响应 // countTokensError 返回 count_tokens 错误响应

View File

@@ -2,32 +2,36 @@ package service
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"log" "log"
"net/http"
"net/url"
"strings"
"time" "time"
"sub2api/internal/model" "sub2api/internal/model"
"sub2api/internal/pkg/oauth" "sub2api/internal/pkg/oauth"
"sub2api/internal/service/ports" "sub2api/internal/service/ports"
"github.com/imroc/req/v3"
) )
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
type ClaudeOAuthClient interface {
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
}
// OAuthService handles OAuth authentication flows // OAuthService handles OAuth authentication flows
type OAuthService struct { type OAuthService struct {
sessionStore *oauth.SessionStore sessionStore *oauth.SessionStore
proxyRepo ports.ProxyRepository proxyRepo ports.ProxyRepository
oauthClient ClaudeOAuthClient
} }
// NewOAuthService creates a new OAuth service // NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService { func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
return &OAuthService{ return &OAuthService{
sessionStore: oauth.NewSessionStore(), sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo, proxyRepo: proxyRepo,
oauthClient: oauthClient,
} }
} }
@@ -210,177 +214,21 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey // getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL) return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
} }
// getAuthorizationCode gets the authorization code using sessionKey // getAuthorizationCode gets the authorization code using sessionKey
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.createReqClient(proxyURL) return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
// Build request body - must include organization_uuid as per CRS
reqBody := map[string]interface{}{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID, // Required field!
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
// Response contains redirect_uri with code, not direct code field
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
// Parse redirect_uri to extract code and state
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
// Combine code with state if present (as CRS does)
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
return fullCode, nil
} }
// exchangeCodeForToken exchanges authorization code for tokens // exchangeCodeForToken exchanges authorization code for tokens
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) { func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL) tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL)
// Parse code#state format if present
authCode := code
codeState := ""
if parts := strings.Split(code, "#"); len(parts) > 1 {
authCode = parts[0]
codeState = parts[1]
}
// Build JSON body as CRS does (not form data!)
reqBody := map[string]interface{}{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
// Add state if present
if codeState != "" {
reqBody["state"] = codeState
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil { if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) return nil, err
return nil, fmt.Errorf("request failed: %w", err)
} }
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
tokenInfo := &TokenInfo{ tokenInfo := &TokenInfo{
AccessToken: tokenResp.AccessToken, AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType, TokenType: tokenResp.TokenType,
@@ -390,7 +238,6 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
Scope: tokenResp.Scope, Scope: tokenResp.Scope,
} }
// Extract org_uuid and account_uuid from response
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" { if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
tokenInfo.OrgUUID = tokenResp.Organization.UUID tokenInfo.OrgUUID = tokenResp.Organization.UUID
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID) log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
@@ -405,27 +252,9 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
// RefreshToken refreshes an OAuth token // RefreshToken refreshes an OAuth token
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) { func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
client := s.createReqClient(proxyURL) tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauth.ClientID)
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("request failed: %w", err) return nil, err
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
} }
return &TokenInfo{ return &TokenInfo{
@@ -455,17 +284,3 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
return s.RefreshToken(ctx, refreshToken, proxyURL) return s.RefreshToken(ctx, refreshToken, proxyURL)
} }
// createReqClient creates a req client with Chrome impersonation and optional proxy
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
SetTimeout(60 * time.Second)
// Set proxy if specified
if proxyURL != "" {
client.SetProxyURL(proxyURL)
}
return client
}

View File

@@ -1,13 +1,12 @@
package service package service
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log" "log"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -20,13 +19,19 @@ import (
// LiteLLMModelPricing LiteLLM价格数据结构 // LiteLLMModelPricing LiteLLM价格数据结构
// 只保留我们需要的字段,使用指针来处理可能缺失的值 // 只保留我们需要的字段,使用指针来处理可能缺失的值
type LiteLLMModelPricing struct { type LiteLLMModelPricing struct {
InputCostPerToken float64 `json:"input_cost_per_token"` InputCostPerToken float64 `json:"input_cost_per_token"`
OutputCostPerToken float64 `json:"output_cost_per_token"` OutputCostPerToken float64 `json:"output_cost_per_token"`
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"` CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"` CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
LiteLLMProvider string `json:"litellm_provider"` LiteLLMProvider string `json:"litellm_provider"`
Mode string `json:"mode"` Mode string `json:"mode"`
SupportsPromptCaching bool `json:"supports_prompt_caching"` SupportsPromptCaching bool `json:"supports_prompt_caching"`
}
// PricingRemoteClient 远程价格数据获取接口
type PricingRemoteClient interface {
FetchPricingJSON(ctx context.Context, url string) ([]byte, error)
FetchHashText(ctx context.Context, url string) (string, error)
} }
// LiteLLMRawEntry 用于解析原始JSON数据 // LiteLLMRawEntry 用于解析原始JSON数据
@@ -42,11 +47,12 @@ type LiteLLMRawEntry struct {
// PricingService 动态价格服务 // PricingService 动态价格服务
type PricingService struct { type PricingService struct {
cfg *config.Config cfg *config.Config
mu sync.RWMutex remoteClient PricingRemoteClient
pricingData map[string]*LiteLLMModelPricing mu sync.RWMutex
lastUpdated time.Time pricingData map[string]*LiteLLMModelPricing
localHash string lastUpdated time.Time
localHash string
// 停止信号 // 停止信号
stopCh chan struct{} stopCh chan struct{}
@@ -54,11 +60,12 @@ type PricingService struct {
} }
// NewPricingService 创建价格服务 // NewPricingService 创建价格服务
func NewPricingService(cfg *config.Config) *PricingService { func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService {
s := &PricingService{ s := &PricingService{
cfg: cfg, cfg: cfg,
pricingData: make(map[string]*LiteLLMModelPricing), remoteClient: remoteClient,
stopCh: make(chan struct{}), pricingData: make(map[string]*LiteLLMModelPricing),
stopCh: make(chan struct{}),
} }
return s return s
} }
@@ -199,21 +206,13 @@ func (s *PricingService) syncWithRemote() error {
func (s *PricingService) downloadPricingData() error { func (s *PricingService) downloadPricingData() error {
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL) log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
client := &http.Client{Timeout: 30 * time.Second} ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
resp, err := client.Get(s.cfg.Pricing.RemoteURL) defer cancel()
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
if err != nil { if err != nil {
return fmt.Errorf("download failed: %w", err) return fmt.Errorf("download failed: %w", err)
} }
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response failed: %w", err)
}
// 解析JSON数据使用灵活的解析方式 // 解析JSON数据使用灵活的解析方式
data, err := s.parsePricingData(body) data, err := s.parsePricingData(body)
@@ -367,29 +366,10 @@ func (s *PricingService) useFallbackPricing() error {
// fetchRemoteHash 从远程获取哈希值 // fetchRemoteHash 从远程获取哈希值
func (s *PricingService) fetchRemoteHash() (string, error) { func (s *PricingService) fetchRemoteHash() (string, error) {
client := &http.Client{Timeout: 10 * time.Second} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
resp, err := client.Get(s.cfg.Pricing.HashURL) defer cancel()
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
} }
// computeFileHash 计算文件哈希 // computeFileHash 计算文件哈希
@@ -466,14 +446,14 @@ func (s *PricingService) extractBaseName(model string) string {
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
// Claude模型系列匹配规则 // Claude模型系列匹配规则
familyPatterns := map[string][]string{ familyPatterns := map[string][]string{
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"}, "opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
"opus-4": {"claude-opus-4", "claude-3-opus"}, "opus-4": {"claude-opus-4", "claude-3-opus"},
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"}, "sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"}, "sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"}, "sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
"sonnet-3": {"claude-3-sonnet"}, "sonnet-3": {"claude-3-sonnet"},
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"}, "haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
"haiku-3": {"claude-3-haiku"}, "haiku-3": {"claude-3-haiku"},
} }
// 确定模型属于哪个系列 // 确定模型属于哪个系列

View File

@@ -26,4 +26,5 @@ type Services struct {
Subscription *SubscriptionService Subscription *SubscriptionService
Concurrency *ConcurrencyService Concurrency *ConcurrencyService
Identity *IdentityService Identity *IdentityService
Update *UpdateService
} }

View File

@@ -2,14 +2,9 @@ package service
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"log" "log"
"net/http"
"net/url"
"strings"
"time"
) )
var ( var (
@@ -19,10 +14,15 @@ var (
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify" const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
// TurnstileVerifier 验证 Turnstile token 的接口
type TurnstileVerifier interface {
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
}
// TurnstileService Turnstile 验证服务 // TurnstileService Turnstile 验证服务
type TurnstileService struct { type TurnstileService struct {
settingService *SettingService settingService *SettingService
httpClient *http.Client verifier TurnstileVerifier
} }
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应 // TurnstileVerifyResponse Cloudflare Turnstile 验证响应
@@ -36,12 +36,10 @@ type TurnstileVerifyResponse struct {
} }
// NewTurnstileService 创建 Turnstile 服务实例 // NewTurnstileService 创建 Turnstile 服务实例
func NewTurnstileService(settingService *SettingService) *TurnstileService { func NewTurnstileService(settingService *SettingService, verifier TurnstileVerifier) *TurnstileService {
return &TurnstileService{ return &TurnstileService{
settingService: settingService, settingService: settingService,
httpClient: &http.Client{ verifier: verifier,
Timeout: 10 * time.Second,
},
} }
} }
@@ -66,35 +64,12 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
return ErrTurnstileVerificationFailed return ErrTurnstileVerificationFailed
} }
// 构建请求
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// 发送请求
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP) log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
resp, err := s.httpClient.Do(req) result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP)
if err != nil { if err != nil {
log.Printf("[Turnstile] Request failed: %v", err) log.Printf("[Turnstile] Request failed: %v", err)
return fmt.Errorf("send request: %w", err) return fmt.Errorf("send request: %w", err)
} }
defer resp.Body.Close()
// 解析响应
var result TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
log.Printf("[Turnstile] Failed to decode response: %v", err)
return fmt.Errorf("decode response: %w", err)
}
if !result.Success { if !result.Success {
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes) log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)

View File

@@ -10,7 +10,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
@@ -34,17 +33,26 @@ const (
maxDownloadSize = 500 * 1024 * 1024 maxDownloadSize = 500 * 1024 * 1024
) )
// GitHubReleaseClient 获取 GitHub release 信息的接口
type GitHubReleaseClient interface {
FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error)
DownloadFile(ctx context.Context, url, dest string, maxSize int64) error
FetchChecksumFile(ctx context.Context, url string) ([]byte, error)
}
// UpdateService handles software updates // UpdateService handles software updates
type UpdateService struct { type UpdateService struct {
cache ports.UpdateCache cache ports.UpdateCache
githubClient GitHubReleaseClient
currentVersion string currentVersion string
buildType string // "source" for manual builds, "release" for CI builds buildType string // "source" for manual builds, "release" for CI builds
} }
// NewUpdateService creates a new UpdateService // NewUpdateService creates a new UpdateService
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService { func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService {
return &UpdateService{ return &UpdateService{
cache: cache, cache: cache,
githubClient: githubClient,
currentVersion: version, currentVersion: version,
buildType: buildType, buildType: buildType,
} }
@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error {
return nil return nil
} }
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) { func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo) release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return &UpdateInfo{
CurrentVersion: s.currentVersion,
LatestVersion: s.currentVersion,
HasUpdate: false,
Warning: "No releases found",
BuildType: s.buildType,
}, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
latestVersion := strings.TrimPrefix(release.TagName, "v") latestVersion := strings.TrimPrefix(release.TagName, "v")
@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
} }
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error { func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil) return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize)
if err != nil {
return err
}
client := &http.Client{Timeout: 10 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxDownloadSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxDownloadSize)
if written > maxDownloadSize {
os.Remove(dest) // Clean up partial file
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
}
return nil
} }
func (s *UpdateService) getArchiveName() string { func (s *UpdateService) getArchiveName() string {
@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error {
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error { func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
// Download checksums file // Download checksums file
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil) checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to download checksums: %w", err)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
} }
// Calculate file hash // Calculate file hash
@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
// Find expected hash in checksums file // Find expected hash in checksums file
fileName := filepath.Base(filePath) fileName := filepath.Base(filePath)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(strings.NewReader(string(checksumData)))
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
parts := strings.Fields(line) parts := strings.Fields(line)

View File

@@ -2,13 +2,20 @@ package service
import ( import (
"sub2api/internal/config" "sub2api/internal/config"
"sub2api/internal/service/ports"
"github.com/google/wire" "github.com/google/wire"
) )
// BuildInfo contains build information
type BuildInfo struct {
Version string
BuildType string
}
// ProvidePricingService creates and initializes PricingService // ProvidePricingService creates and initializes PricingService
func ProvidePricingService(cfg *config.Config) (*PricingService, error) { func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
svc := NewPricingService(cfg) svc := NewPricingService(cfg, remoteClient)
if err := svc.Initialize(); err != nil { if err := svc.Initialize(); err != nil {
// 价格服务初始化失败不应阻止启动,使用回退价格 // 价格服务初始化失败不应阻止启动,使用回退价格
println("[Service] Warning: Pricing service initialization failed:", err.Error()) println("[Service] Warning: Pricing service initialization failed:", err.Error())
@@ -16,6 +23,11 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
return svc, nil return svc, nil
} }
// ProvideUpdateService creates UpdateService with BuildInfo
func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService {
return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType)
}
// ProvideEmailQueueService creates EmailQueueService with default worker count // ProvideEmailQueueService creates EmailQueueService with default worker count
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
return NewEmailQueueService(emailService, 3) return NewEmailQueueService(emailService, 3)
@@ -48,6 +60,7 @@ var ProviderSet = wire.NewSet(
NewSubscriptionService, NewSubscriptionService,
NewConcurrencyService, NewConcurrencyService,
NewIdentityService, NewIdentityService,
ProvideUpdateService,
// Provide the Services container struct // Provide the Services container struct
wire.Struct(new(Services), "*"), wire.Struct(new(Services), "*"),