Merge pull request #9 from NepetaLemon/refactor/add-http-service-ports
refactor(backend): service http ports
This commit is contained in:
@@ -40,6 +40,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
// 服务器层 ProviderSet
|
// 服务器层 ProviderSet
|
||||||
server.ProviderSet,
|
server.ProviderSet,
|
||||||
|
|
||||||
|
// BuildInfo provider
|
||||||
|
provideServiceBuildInfo,
|
||||||
|
|
||||||
// 清理函数提供者
|
// 清理函数提供者
|
||||||
provideCleanup,
|
provideCleanup,
|
||||||
|
|
||||||
@@ -49,6 +52,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
|
return service.BuildInfo{
|
||||||
|
Version: buildInfo.Version,
|
||||||
|
BuildType: buildInfo.BuildType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func provideCleanup(
|
func provideCleanup(
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
rdb *redis.Client,
|
rdb *redis.Client,
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
client := infrastructure.ProvideRedis(configConfig)
|
client := infrastructure.ProvideRedis(configConfig)
|
||||||
emailCache := repository.NewEmailCache(client)
|
emailCache := repository.NewEmailCache(client)
|
||||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||||
turnstileService := service.NewTurnstileService(settingService)
|
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||||
|
turnstileService := service.NewTurnstileService(settingService, turnstileVerifier)
|
||||||
emailQueueService := service.ProvideEmailQueueService(emailService)
|
emailQueueService := service.ProvideEmailQueueService(emailService)
|
||||||
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService)
|
||||||
authHandler := handler.NewAuthHandler(authService)
|
authHandler := handler.NewAuthHandler(authService)
|
||||||
@@ -68,32 +69,41 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
|
||||||
accountRepository := repository.NewAccountRepository(db)
|
accountRepository := repository.NewAccountRepository(db)
|
||||||
proxyRepository := repository.NewProxyRepository(db)
|
proxyRepository := repository.NewProxyRepository(db)
|
||||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber()
|
||||||
|
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService, proxyExitInfoProber)
|
||||||
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
|
||||||
adminUserHandler := admin.NewUserHandler(adminService)
|
adminUserHandler := admin.NewUserHandler(adminService)
|
||||||
groupHandler := admin.NewGroupHandler(adminService)
|
groupHandler := admin.NewGroupHandler(adminService)
|
||||||
oAuthService := service.NewOAuthService(proxyRepository)
|
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||||
|
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||||
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService)
|
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
|
||||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService, claudeUsageFetcher)
|
||||||
|
claudeUpstream := repository.NewClaudeUpstream(configConfig)
|
||||||
|
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, claudeUpstream)
|
||||||
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
|
||||||
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
|
||||||
proxyHandler := admin.NewProxyHandler(adminService)
|
proxyHandler := admin.NewProxyHandler(adminService)
|
||||||
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
adminRedeemHandler := admin.NewRedeemHandler(adminService)
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
settingHandler := admin.NewSettingHandler(settingService, emailService)
|
||||||
systemHandler := handler.ProvideSystemHandler(client, buildInfo)
|
updateCache := repository.NewUpdateCache(client)
|
||||||
|
gitHubReleaseClient := repository.NewGitHubReleaseClient()
|
||||||
|
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
|
||||||
|
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
|
||||||
|
systemHandler := handler.ProvideSystemHandler(updateService)
|
||||||
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
|
||||||
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler)
|
||||||
gatewayCache := repository.NewGatewayCache(client)
|
gatewayCache := repository.NewGatewayCache(client)
|
||||||
pricingService, err := service.ProvidePricingService(configConfig)
|
pricingRemoteClient := repository.NewPricingRemoteClient()
|
||||||
|
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
billingService := service.NewBillingService(configConfig, pricingService)
|
billingService := service.NewBillingService(configConfig, pricingService)
|
||||||
identityCache := repository.NewIdentityCache(client)
|
identityCache := repository.NewIdentityCache(client)
|
||||||
identityService := service.NewIdentityService(identityCache)
|
identityService := service.NewIdentityService(identityCache)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
|
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService, claudeUpstream)
|
||||||
concurrencyCache := repository.NewConcurrencyCache(client)
|
concurrencyCache := repository.NewConcurrencyCache(client)
|
||||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||||
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
|
||||||
@@ -127,6 +137,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
Subscription: subscriptionService,
|
Subscription: subscriptionService,
|
||||||
Concurrency: concurrencyService,
|
Concurrency: concurrencyService,
|
||||||
Identity: identityService,
|
Identity: identityService,
|
||||||
|
Update: updateService,
|
||||||
}
|
}
|
||||||
repositories := &repository.Repositories{
|
repositories := &repository.Repositories{
|
||||||
User: userRepository,
|
User: userRepository,
|
||||||
@@ -156,6 +167,13 @@ type Application struct {
|
|||||||
Cleanup func()
|
Cleanup func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo {
|
||||||
|
return service.BuildInfo{
|
||||||
|
Version: buildInfo.Version,
|
||||||
|
BuildType: buildInfo.BuildType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func provideCleanup(
|
func provideCleanup(
|
||||||
db *gorm.DB,
|
db *gorm.DB,
|
||||||
rdb *redis.Client,
|
rdb *redis.Client,
|
||||||
|
|||||||
@@ -6,11 +6,9 @@ import (
|
|||||||
|
|
||||||
"sub2api/internal/pkg/response"
|
"sub2api/internal/pkg/response"
|
||||||
"sub2api/internal/pkg/sysutil"
|
"sub2api/internal/pkg/sysutil"
|
||||||
"sub2api/internal/repository"
|
|
||||||
"sub2api/internal/service"
|
"sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SystemHandler handles system-related operations
|
// SystemHandler handles system-related operations
|
||||||
@@ -19,10 +17,9 @@ type SystemHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewSystemHandler creates a new SystemHandler
|
// NewSystemHandler creates a new SystemHandler
|
||||||
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
|
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
|
||||||
updateCache := repository.NewUpdateCache(rdb)
|
|
||||||
return &SystemHandler{
|
return &SystemHandler{
|
||||||
updateSvc: service.NewUpdateService(updateCache, version, buildType),
|
updateSvc: updateSvc,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"sub2api/internal/service"
|
"sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProvideAdminHandlers creates the AdminHandlers struct
|
// ProvideAdminHandlers creates the AdminHandlers struct
|
||||||
@@ -37,9 +36,9 @@ func ProvideAdminHandlers(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters
|
// ProvideSystemHandler creates admin.SystemHandler with UpdateService
|
||||||
func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler {
|
func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler {
|
||||||
return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType)
|
return admin.NewSystemHandler(updateService)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
|
||||||
|
|||||||
235
backend/internal/repository/claude_oauth_service.go
Normal file
235
backend/internal/repository/claude_oauth_service.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/pkg/oauth"
|
||||||
|
"sub2api/internal/service"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeOAuthService struct{}
|
||||||
|
|
||||||
|
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||||
|
return &claudeOAuthService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
var orgs []struct {
|
||||||
|
UUID string `json:"uuid"`
|
||||||
|
}
|
||||||
|
|
||||||
|
targetURL := "https://claude.ai/api/organizations"
|
||||||
|
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetCookies(&http.Cookie{
|
||||||
|
Name: "sessionKey",
|
||||||
|
Value: sessionKey,
|
||||||
|
}).
|
||||||
|
SetSuccessResult(&orgs).
|
||||||
|
Get(targetURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||||
|
return "", fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(orgs) == 0 {
|
||||||
|
return "", fmt.Errorf("no organizations found")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||||
|
return orgs[0].UUID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"response_type": "code",
|
||||||
|
"client_id": oauth.ClientID,
|
||||||
|
"organization_uuid": orgUUID,
|
||||||
|
"redirect_uri": oauth.RedirectURI,
|
||||||
|
"scope": scope,
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": codeChallenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||||
|
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||||
|
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetCookies(&http.Cookie{
|
||||||
|
Name: "sessionKey",
|
||||||
|
Value: sessionKey,
|
||||||
|
}).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||||
|
SetHeader("Cache-Control", "no-cache").
|
||||||
|
SetHeader("Origin", "https://claude.ai").
|
||||||
|
SetHeader("Referer", "https://claude.ai/new").
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(reqBody).
|
||||||
|
SetSuccessResult(&result).
|
||||||
|
Post(authURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||||
|
return "", fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RedirectURI == "" {
|
||||||
|
return "", fmt.Errorf("no redirect_uri in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, err := url.Parse(result.RedirectURI)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
queryParams := parsedURL.Query()
|
||||||
|
authCode := queryParams.Get("code")
|
||||||
|
responseState := queryParams.Get("state")
|
||||||
|
|
||||||
|
if authCode == "" {
|
||||||
|
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||||
|
}
|
||||||
|
|
||||||
|
fullCode := authCode
|
||||||
|
if responseState != "" {
|
||||||
|
fullCode = authCode + "#" + responseState
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
||||||
|
return fullCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
authCode := code
|
||||||
|
codeState := ""
|
||||||
|
if len(code) > 0 {
|
||||||
|
parts := make([]string, 0, 2)
|
||||||
|
for i, part := range []rune(code) {
|
||||||
|
if part == '#' {
|
||||||
|
authCode = code[:i]
|
||||||
|
codeState = code[i+1:]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
authCode = code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"code": authCode,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": oauth.ClientID,
|
||||||
|
"redirect_uri": oauth.RedirectURI,
|
||||||
|
"code_verifier": codeVerifier,
|
||||||
|
}
|
||||||
|
|
||||||
|
if codeState != "" {
|
||||||
|
reqBody["state"] = codeState
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBodyJSON, _ := json.Marshal(reqBody)
|
||||||
|
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
||||||
|
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||||
|
|
||||||
|
var tokenResp oauth.TokenResponse
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetHeader("Content-Type", "application/json").
|
||||||
|
SetBody(reqBody).
|
||||||
|
SetSuccessResult(&tokenResp).
|
||||||
|
Post(oauth.TokenURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||||
|
client := createReqClient(proxyURL)
|
||||||
|
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("grant_type", "refresh_token")
|
||||||
|
formData.Set("refresh_token", refreshToken)
|
||||||
|
formData.Set("client_id", oauth.ClientID)
|
||||||
|
|
||||||
|
var tokenResp oauth.TokenResponse
|
||||||
|
|
||||||
|
resp, err := client.R().
|
||||||
|
SetContext(ctx).
|
||||||
|
SetFormDataFromValues(formData).
|
||||||
|
SetSuccessResult(&tokenResp).
|
||||||
|
Post(oauth.TokenURL)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.IsSuccessState() {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createReqClient(proxyURL string) *req.Client {
|
||||||
|
client := req.C().
|
||||||
|
ImpersonateChrome().
|
||||||
|
SetTimeout(60 * time.Second)
|
||||||
|
|
||||||
|
if proxyURL != "" {
|
||||||
|
client.SetProxyURL(proxyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
64
backend/internal/repository/claude_service.go
Normal file
64
backend/internal/repository/claude_service.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/config"
|
||||||
|
"sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeUpstreamService struct {
|
||||||
|
defaultClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeUpstream(cfg *config.Config) service.ClaudeUpstream {
|
||||||
|
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||||
|
if responseHeaderTimeout == 0 {
|
||||||
|
responseHeaderTimeout = 300 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &claudeUpstreamService{
|
||||||
|
defaultClient: &http.Client{Transport: transport},
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeUpstreamService) Do(req *http.Request, proxyURL string) (*http.Response, error) {
|
||||||
|
if proxyURL == "" {
|
||||||
|
return s.defaultClient.Do(req)
|
||||||
|
}
|
||||||
|
client := s.createProxyClient(proxyURL)
|
||||||
|
return client.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeUpstreamService) createProxyClient(proxyURL string) *http.Client {
|
||||||
|
parsedURL, err := url.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return s.defaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||||
|
if responseHeaderTimeout == 0 {
|
||||||
|
responseHeaderTimeout = 300 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(parsedURL),
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{Transport: transport}
|
||||||
|
}
|
||||||
59
backend/internal/repository/claude_usage_service.go
Normal file
59
backend/internal/repository/claude_usage_service.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type claudeUsageService struct{}
|
||||||
|
|
||||||
|
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
|
||||||
|
return &claudeUsageService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
if proxyURL != "" {
|
||||||
|
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
||||||
|
transport.Proxy = http.ProxyURL(parsedURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var usageResp service.ClaudeUsageResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &usageResp, nil
|
||||||
|
}
|
||||||
116
backend/internal/repository/github_release_service.go
Normal file
116
backend/internal/repository/github_release_service.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type githubReleaseClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGitHubReleaseClient() service.GitHubReleaseClient {
|
||||||
|
return &githubReleaseClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||||
|
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||||
|
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var release service.GitHubRelease
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &release, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Minute}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SECURITY: Check Content-Length if available
|
||||||
|
if resp.ContentLength > maxSize {
|
||||||
|
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := os.Create(dest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||||
|
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||||
|
written, err := io.Copy(out, limited)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we hit the limit (downloaded more than maxSize)
|
||||||
|
if written > maxSize {
|
||||||
|
os.Remove(dest) // Clean up partial file
|
||||||
|
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
73
backend/internal/repository/pricing_service.go
Normal file
73
backend/internal/repository/pricing_service.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pricingRemoteClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPricingRemoteClient() service.PricingRemoteClient {
|
||||||
|
return &pricingRemoteClient{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 哈希文件格式:hash filename 或者纯 hash
|
||||||
|
hash := strings.TrimSpace(string(body))
|
||||||
|
parts := strings.Fields(hash)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
return parts[0], nil
|
||||||
|
}
|
||||||
|
return hash, nil
|
||||||
|
}
|
||||||
104
backend/internal/repository/proxy_probe_service.go
Normal file
104
backend/internal/repository/proxy_probe_service.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type proxyProbeService struct{}
|
||||||
|
|
||||||
|
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
|
||||||
|
return &proxyProbeService{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||||
|
transport, err := createProxyTransport(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to create proxy transport: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
Timeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
latencyMs := time.Since(startTime).Milliseconds()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipInfo struct {
|
||||||
|
IP string `json:"ip"`
|
||||||
|
City string `json:"city"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Country string `json:"country"`
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||||
|
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &service.ProxyExitInfo{
|
||||||
|
IP: ipInfo.IP,
|
||||||
|
City: ipInfo.City,
|
||||||
|
Region: ipInfo.Region,
|
||||||
|
Country: ipInfo.Country,
|
||||||
|
}, latencyMs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
||||||
|
parsedURL, err := url.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parsedURL.Scheme {
|
||||||
|
case "http", "https":
|
||||||
|
transport.Proxy = http.ProxyURL(parsedURL)
|
||||||
|
case "socks5":
|
||||||
|
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
||||||
|
}
|
||||||
|
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(network, addr)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
return transport, nil
|
||||||
|
}
|
||||||
55
backend/internal/repository/turnstile_service.go
Normal file
55
backend/internal/repository/turnstile_service.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"sub2api/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||||
|
|
||||||
|
type turnstileVerifier struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTurnstileVerifier() service.TurnstileVerifier {
|
||||||
|
return &turnstileVerifier{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
|
||||||
|
formData := url.Values{}
|
||||||
|
formData.Set("secret", secretKey)
|
||||||
|
formData.Set("response", token)
|
||||||
|
if remoteIP != "" {
|
||||||
|
formData.Set("remoteip", remoteIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := v.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result service.TurnstileVerifyResponse
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
@@ -29,6 +29,15 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewRedeemCache,
|
NewRedeemCache,
|
||||||
NewUpdateCache,
|
NewUpdateCache,
|
||||||
|
|
||||||
|
// HTTP service ports (DI Strategy A: return interface directly)
|
||||||
|
NewTurnstileVerifier,
|
||||||
|
NewPricingRemoteClient,
|
||||||
|
NewGitHubReleaseClient,
|
||||||
|
NewProxyExitInfoProber,
|
||||||
|
NewClaudeUsageFetcher,
|
||||||
|
NewClaudeOAuthClient,
|
||||||
|
NewClaudeUpstream,
|
||||||
|
|
||||||
// Bind concrete repositories to service port interfaces
|
// Bind concrete repositories to service port interfaces
|
||||||
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
|
||||||
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -37,19 +36,17 @@ type TestEvent struct {
|
|||||||
|
|
||||||
// AccountTestService handles account testing operations
|
// AccountTestService handles account testing operations
|
||||||
type AccountTestService struct {
|
type AccountTestService struct {
|
||||||
accountRepo ports.AccountRepository
|
accountRepo ports.AccountRepository
|
||||||
oauthService *OAuthService
|
oauthService *OAuthService
|
||||||
httpClient *http.Client
|
claudeUpstream ClaudeUpstream
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountTestService creates a new AccountTestService
|
// NewAccountTestService creates a new AccountTestService
|
||||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService {
|
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService, claudeUpstream ClaudeUpstream) *AccountTestService {
|
||||||
return &AccountTestService{
|
return &AccountTestService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
httpClient: &http.Client{
|
claudeUpstream: claudeUpstream,
|
||||||
Timeout: 60 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,23 +206,13 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
|||||||
req.Header.Set("x-api-key", authToken)
|
req.Header.Set("x-api-key", authToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure proxy if account has one
|
// Get proxy URL
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
proxyURL := ""
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL := account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
if proxyURL != "" {
|
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
resp, err := s.claudeUpstream.Do(req, proxyURL)
|
||||||
Transport: transport,
|
|
||||||
Timeout: 60 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -65,23 +61,26 @@ type ClaudeUsageResponse struct {
|
|||||||
} `json:"seven_day_sonnet"`
|
} `json:"seven_day_sonnet"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||||
|
type ClaudeUsageFetcher interface {
|
||||||
|
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
// AccountUsageService 账号使用量查询服务
|
// AccountUsageService 账号使用量查询服务
|
||||||
type AccountUsageService struct {
|
type AccountUsageService struct {
|
||||||
accountRepo ports.AccountRepository
|
accountRepo ports.AccountRepository
|
||||||
usageLogRepo ports.UsageLogRepository
|
usageLogRepo ports.UsageLogRepository
|
||||||
oauthService *OAuthService
|
oauthService *OAuthService
|
||||||
httpClient *http.Client
|
usageFetcher ClaudeUsageFetcher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountUsageService 创建AccountUsageService实例
|
// NewAccountUsageService 创建AccountUsageService实例
|
||||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService {
|
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||||
return &AccountUsageService{
|
return &AccountUsageService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
oauthService: oauthService,
|
oauthService: oauthService,
|
||||||
httpClient: &http.Client{
|
usageFetcher: usageFetcher,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,58 +178,23 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
|
|||||||
|
|
||||||
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
// fetchOAuthUsage 从Anthropic API获取OAuth账号的使用量
|
||||||
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
func (s *AccountUsageService) fetchOAuthUsage(ctx context.Context, account *model.Account) (*UsageInfo, error) {
|
||||||
// 获取access token(从credentials中获取)
|
|
||||||
accessToken := account.GetCredential("access_token")
|
accessToken := account.GetCredential("access_token")
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("no access token available")
|
return nil, fmt.Errorf("no access token available")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取代理配置
|
var proxyURL string
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
proxyURL := account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
if proxyURL != "" {
|
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
usageResp, err := s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||||
Transport: transport,
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create request failed: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
|
||||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析响应
|
|
||||||
var usageResp ClaudeUsageResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
|
||||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转换为UsageInfo
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
return s.buildUsageInfo(&usageResp, &now), nil
|
return s.buildUsageInfo(usageResp, &now), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseTime 尝试多种格式解析时间
|
// parseTime 尝试多种格式解析时间
|
||||||
|
|||||||
@@ -2,21 +2,14 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/pagination"
|
"sub2api/internal/pkg/pagination"
|
||||||
"sub2api/internal/service/ports"
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,6 +171,19 @@ type ProxyTestResult struct {
|
|||||||
Country string `json:"country,omitempty"`
|
Country string `json:"country,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyExitInfo represents proxy exit information from ipinfo.io
|
||||||
|
type ProxyExitInfo struct {
|
||||||
|
IP string
|
||||||
|
City string
|
||||||
|
Region string
|
||||||
|
Country string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
|
||||||
|
type ProxyExitInfoProber interface {
|
||||||
|
ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
// adminServiceImpl implements AdminService
|
// adminServiceImpl implements AdminService
|
||||||
type adminServiceImpl struct {
|
type adminServiceImpl struct {
|
||||||
userRepo ports.UserRepository
|
userRepo ports.UserRepository
|
||||||
@@ -189,6 +195,7 @@ type adminServiceImpl struct {
|
|||||||
usageLogRepo ports.UsageLogRepository
|
usageLogRepo ports.UsageLogRepository
|
||||||
userSubRepo ports.UserSubscriptionRepository
|
userSubRepo ports.UserSubscriptionRepository
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
|
proxyProber ProxyExitInfoProber
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAdminService creates a new AdminService
|
// NewAdminService creates a new AdminService
|
||||||
@@ -202,6 +209,7 @@ func NewAdminService(
|
|||||||
usageLogRepo ports.UsageLogRepository,
|
usageLogRepo ports.UsageLogRepository,
|
||||||
userSubRepo ports.UserSubscriptionRepository,
|
userSubRepo ports.UserSubscriptionRepository,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
|
proxyProber ProxyExitInfoProber,
|
||||||
) AdminService {
|
) AdminService {
|
||||||
return &adminServiceImpl{
|
return &adminServiceImpl{
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
@@ -213,6 +221,7 @@ func NewAdminService(
|
|||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
userSubRepo: userSubRepo,
|
userSubRepo: userSubRepo,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
|
proxyProber: proxyProber,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -876,79 +885,12 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return testProxyConnection(ctx, proxy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// testProxyConnection tests proxy connectivity by requesting ipinfo.io/json
|
|
||||||
func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestResult, error) {
|
|
||||||
proxyURL := proxy.URL()
|
proxyURL := proxy.URL()
|
||||||
|
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
|
||||||
// Create HTTP client with proxy
|
|
||||||
transport, err := createProxyTransport(proxyURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &ProxyTestResult{
|
return &ProxyTestResult{
|
||||||
Success: false,
|
Success: false,
|
||||||
Message: fmt.Sprintf("Failed to create proxy transport: %v", err),
|
Message: err.Error(),
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
Timeout: 15 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Measure latency
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
|
|
||||||
if err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: false,
|
|
||||||
Message: fmt.Sprintf("Failed to create request: %v", err),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: false,
|
|
||||||
Message: fmt.Sprintf("Proxy connection failed: %v", err),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
latencyMs := time.Since(startTime).Milliseconds()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: false,
|
|
||||||
Message: fmt.Sprintf("Request failed with status: %d", resp.StatusCode),
|
|
||||||
LatencyMs: latencyMs,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse ipinfo.io response
|
|
||||||
var ipInfo struct {
|
|
||||||
IP string `json:"ip"`
|
|
||||||
City string `json:"city"`
|
|
||||||
Region string `json:"region"`
|
|
||||||
Country string `json:"country"`
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: true,
|
|
||||||
Message: "Proxy is accessible but failed to read response",
|
|
||||||
LatencyMs: latencyMs,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
|
||||||
return &ProxyTestResult{
|
|
||||||
Success: true,
|
|
||||||
Message: "Proxy is accessible but failed to parse response",
|
|
||||||
LatencyMs: latencyMs,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -956,38 +898,9 @@ func testProxyConnection(ctx context.Context, proxy *model.Proxy) (*ProxyTestRes
|
|||||||
Success: true,
|
Success: true,
|
||||||
Message: "Proxy is accessible",
|
Message: "Proxy is accessible",
|
||||||
LatencyMs: latencyMs,
|
LatencyMs: latencyMs,
|
||||||
IPAddress: ipInfo.IP,
|
IPAddress: exitInfo.IP,
|
||||||
City: ipInfo.City,
|
City: exitInfo.City,
|
||||||
Region: ipInfo.Region,
|
Region: exitInfo.Region,
|
||||||
Country: ipInfo.Country,
|
Country: exitInfo.Country,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createProxyTransport creates an HTTP transport with the given proxy URL
|
|
||||||
func createProxyTransport(proxyURL string) (*http.Transport, error) {
|
|
||||||
parsedURL, err := url.Parse(proxyURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid proxy URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
}
|
|
||||||
|
|
||||||
switch parsedURL.Scheme {
|
|
||||||
case "http", "https":
|
|
||||||
transport.Proxy = http.ProxyURL(parsedURL)
|
|
||||||
case "socks5":
|
|
||||||
dialer, err := proxy.FromURL(parsedURL, proxy.Direct)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create socks5 dialer: %w", err)
|
|
||||||
}
|
|
||||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsedURL.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
return transport, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -26,6 +25,11 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ClaudeUpstream handles HTTP requests to Claude API
|
||||||
|
type ClaudeUpstream interface {
|
||||||
|
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
|
||||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||||
@@ -87,7 +91,7 @@ type GatewayService struct {
|
|||||||
rateLimitService *RateLimitService
|
rateLimitService *RateLimitService
|
||||||
billingCacheService *BillingCacheService
|
billingCacheService *BillingCacheService
|
||||||
identityService *IdentityService
|
identityService *IdentityService
|
||||||
httpClient *http.Client
|
claudeUpstream ClaudeUpstream
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGatewayService creates a new GatewayService
|
// NewGatewayService creates a new GatewayService
|
||||||
@@ -103,20 +107,8 @@ func NewGatewayService(
|
|||||||
rateLimitService *RateLimitService,
|
rateLimitService *RateLimitService,
|
||||||
billingCacheService *BillingCacheService,
|
billingCacheService *BillingCacheService,
|
||||||
identityService *IdentityService,
|
identityService *IdentityService,
|
||||||
|
claudeUpstream ClaudeUpstream,
|
||||||
) *GatewayService {
|
) *GatewayService {
|
||||||
// 计算响应头超时时间
|
|
||||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
|
||||||
if responseHeaderTimeout == 0 {
|
|
||||||
responseHeaderTimeout = 300 * time.Second // 默认5分钟,LLM高负载时可能排队较久
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := &http.Transport{
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
ResponseHeaderTimeout: responseHeaderTimeout, // 等待上游响应头的超时
|
|
||||||
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输
|
|
||||||
}
|
|
||||||
return &GatewayService{
|
return &GatewayService{
|
||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
usageLogRepo: usageLogRepo,
|
usageLogRepo: usageLogRepo,
|
||||||
@@ -129,11 +121,7 @@ func NewGatewayService(
|
|||||||
rateLimitService: rateLimitService,
|
rateLimitService: rateLimitService,
|
||||||
billingCacheService: billingCacheService,
|
billingCacheService: billingCacheService,
|
||||||
identityService: identityService,
|
identityService: identityService,
|
||||||
httpClient: &http.Client{
|
claudeUpstream: claudeUpstream,
|
||||||
Transport: transport,
|
|
||||||
// 不设置 Timeout:流式请求可能持续十几分钟
|
|
||||||
// 超时控制由 Transport.ResponseHeaderTimeout 负责(只控制等待响应头)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -436,19 +424,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构建上游请求
|
// 构建上游请求
|
||||||
upstreamResult, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选择使用的client:如果有代理则使用独立的client,否则使用共享的httpClient
|
// 获取代理URL
|
||||||
httpClient := s.httpClient
|
proxyURL := ""
|
||||||
if upstreamResult.Client != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
httpClient = upstreamResult.Client
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := httpClient.Do(upstreamResult.Request)
|
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -461,16 +449,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||||
}
|
}
|
||||||
upstreamResult, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
upstreamReq, err = s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 重试时也需要使用正确的client
|
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||||
httpClient = s.httpClient
|
|
||||||
if upstreamResult.Client != nil {
|
|
||||||
httpClient = upstreamResult.Client
|
|
||||||
}
|
|
||||||
resp, err = httpClient.Do(upstreamResult.Request)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("retry request failed: %w", err)
|
return nil, fmt.Errorf("retry request failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -509,13 +492,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUpstreamRequestResult contains the request and optional custom client for proxy
|
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||||
type buildUpstreamRequestResult struct {
|
|
||||||
Request *http.Request
|
|
||||||
Client *http.Client // nil means use default s.httpClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
|
|
||||||
// 确定目标URL
|
// 确定目标URL
|
||||||
targetURL := claudeAPIURL
|
targetURL := claudeAPIURL
|
||||||
if account.Type == model.AccountTypeApiKey {
|
if account.Type == model.AccountTypeApiKey {
|
||||||
@@ -584,36 +561,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 配置代理 - 创建独立的client避免并发修改共享httpClient
|
return req, nil
|
||||||
var customClient *http.Client
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL := account.Proxy.URL()
|
|
||||||
if proxyURL != "" {
|
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
|
||||||
// 计算响应头超时时间(与默认 Transport 保持一致)
|
|
||||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
|
||||||
if responseHeaderTimeout == 0 {
|
|
||||||
responseHeaderTimeout = 300 * time.Second
|
|
||||||
}
|
|
||||||
transport := &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(parsedURL),
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
|
||||||
}
|
|
||||||
// 创建独立的client,避免并发时修改共享的s.httpClient.Transport
|
|
||||||
customClient = &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &buildUpstreamRequestResult{
|
|
||||||
Request: req,
|
|
||||||
Client: customClient,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getBetaHeader 处理anthropic-beta header
|
// getBetaHeader 处理anthropic-beta header
|
||||||
@@ -1085,20 +1033,20 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构建上游请求
|
// 构建上游请求
|
||||||
upstreamResult, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选择 HTTP client
|
// 获取代理URL
|
||||||
httpClient := s.httpClient
|
proxyURL := ""
|
||||||
if upstreamResult.Client != nil {
|
if account.ProxyID != nil && account.Proxy != nil {
|
||||||
httpClient = upstreamResult.Client
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
resp, err := httpClient.Do(upstreamResult.Request)
|
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||||
return fmt.Errorf("upstream request failed: %w", err)
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
@@ -1113,15 +1061,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Token refresh failed")
|
||||||
return fmt.Errorf("token refresh failed: %w", err)
|
return fmt.Errorf("token refresh failed: %w", err)
|
||||||
}
|
}
|
||||||
upstreamResult, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
upstreamReq, err = s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
httpClient = s.httpClient
|
resp, err = s.claudeUpstream.Do(upstreamReq, proxyURL)
|
||||||
if upstreamResult.Client != nil {
|
|
||||||
httpClient = upstreamResult.Client
|
|
||||||
}
|
|
||||||
resp, err = httpClient.Do(upstreamResult.Request)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Retry failed")
|
||||||
return fmt.Errorf("retry request failed: %w", err)
|
return fmt.Errorf("retry request failed: %w", err)
|
||||||
@@ -1159,7 +1103,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*buildUpstreamRequestResult, error) {
|
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *model.Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||||
// 确定目标 URL
|
// 确定目标 URL
|
||||||
targetURL := claudeAPICountTokensURL
|
targetURL := claudeAPICountTokensURL
|
||||||
if account.Type == model.AccountTypeApiKey {
|
if account.Type == model.AccountTypeApiKey {
|
||||||
@@ -1223,32 +1167,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 配置代理
|
return req, nil
|
||||||
var customClient *http.Client
|
|
||||||
if account.ProxyID != nil && account.Proxy != nil {
|
|
||||||
proxyURL := account.Proxy.URL()
|
|
||||||
if proxyURL != "" {
|
|
||||||
if parsedURL, err := url.Parse(proxyURL); err == nil {
|
|
||||||
responseHeaderTimeout := time.Duration(s.cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
|
||||||
if responseHeaderTimeout == 0 {
|
|
||||||
responseHeaderTimeout = 300 * time.Second
|
|
||||||
}
|
|
||||||
transport := &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(parsedURL),
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
|
||||||
}
|
|
||||||
customClient = &http.Client{Transport: transport}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &buildUpstreamRequestResult{
|
|
||||||
Request: req,
|
|
||||||
Client: customClient,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensError 返回 count_tokens 错误响应
|
// countTokensError 返回 count_tokens 错误响应
|
||||||
|
|||||||
@@ -2,32 +2,36 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"sub2api/internal/model"
|
"sub2api/internal/model"
|
||||||
"sub2api/internal/pkg/oauth"
|
"sub2api/internal/pkg/oauth"
|
||||||
"sub2api/internal/service/ports"
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ClaudeOAuthClient handles HTTP requests for Claude OAuth flows
|
||||||
|
type ClaudeOAuthClient interface {
|
||||||
|
GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error)
|
||||||
|
GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error)
|
||||||
|
ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error)
|
||||||
|
RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
// OAuthService handles OAuth authentication flows
|
// OAuthService handles OAuth authentication flows
|
||||||
type OAuthService struct {
|
type OAuthService struct {
|
||||||
sessionStore *oauth.SessionStore
|
sessionStore *oauth.SessionStore
|
||||||
proxyRepo ports.ProxyRepository
|
proxyRepo ports.ProxyRepository
|
||||||
|
oauthClient ClaudeOAuthClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOAuthService creates a new OAuth service
|
// NewOAuthService creates a new OAuth service
|
||||||
func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService {
|
func NewOAuthService(proxyRepo ports.ProxyRepository, oauthClient ClaudeOAuthClient) *OAuthService {
|
||||||
return &OAuthService{
|
return &OAuthService{
|
||||||
sessionStore: oauth.NewSessionStore(),
|
sessionStore: oauth.NewSessionStore(),
|
||||||
proxyRepo: proxyRepo,
|
proxyRepo: proxyRepo,
|
||||||
|
oauthClient: oauthClient,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,177 +214,21 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
|||||||
|
|
||||||
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
|
// getOrganizationUUID gets the organization UUID from claude.ai using sessionKey
|
||||||
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
func (s *OAuthService) getOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||||
client := s.createReqClient(proxyURL)
|
return s.oauthClient.GetOrganizationUUID(ctx, sessionKey, proxyURL)
|
||||||
|
|
||||||
var orgs []struct {
|
|
||||||
UUID string `json:"uuid"`
|
|
||||||
}
|
|
||||||
|
|
||||||
targetURL := "https://claude.ai/api/organizations"
|
|
||||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
|
||||||
|
|
||||||
resp, err := client.R().
|
|
||||||
SetContext(ctx).
|
|
||||||
SetCookies(&http.Cookie{
|
|
||||||
Name: "sessionKey",
|
|
||||||
Value: sessionKey,
|
|
||||||
}).
|
|
||||||
SetSuccessResult(&orgs).
|
|
||||||
Get(targetURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
|
||||||
return "", fmt.Errorf("request failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 1 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
|
||||||
|
|
||||||
if !resp.IsSuccessState() {
|
|
||||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(orgs) == 0 {
|
|
||||||
return "", fmt.Errorf("no organizations found")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
|
||||||
return orgs[0].UUID, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAuthorizationCode gets the authorization code using sessionKey
|
// getAuthorizationCode gets the authorization code using sessionKey
|
||||||
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
func (s *OAuthService) getAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||||
client := s.createReqClient(proxyURL)
|
return s.oauthClient.GetAuthorizationCode(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL)
|
||||||
|
|
||||||
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
|
|
||||||
|
|
||||||
// Build request body - must include organization_uuid as per CRS
|
|
||||||
reqBody := map[string]interface{}{
|
|
||||||
"response_type": "code",
|
|
||||||
"client_id": oauth.ClientID,
|
|
||||||
"organization_uuid": orgUUID, // Required field!
|
|
||||||
"redirect_uri": oauth.RedirectURI,
|
|
||||||
"scope": scope,
|
|
||||||
"state": state,
|
|
||||||
"code_challenge": codeChallenge,
|
|
||||||
"code_challenge_method": "S256",
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
|
||||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
|
||||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
|
||||||
|
|
||||||
// Response contains redirect_uri with code, not direct code field
|
|
||||||
var result struct {
|
|
||||||
RedirectURI string `json:"redirect_uri"`
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.R().
|
|
||||||
SetContext(ctx).
|
|
||||||
SetCookies(&http.Cookie{
|
|
||||||
Name: "sessionKey",
|
|
||||||
Value: sessionKey,
|
|
||||||
}).
|
|
||||||
SetHeader("Accept", "application/json").
|
|
||||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
|
||||||
SetHeader("Cache-Control", "no-cache").
|
|
||||||
SetHeader("Origin", "https://claude.ai").
|
|
||||||
SetHeader("Referer", "https://claude.ai/new").
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(reqBody).
|
|
||||||
SetSuccessResult(&result).
|
|
||||||
Post(authURL)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
|
||||||
return "", fmt.Errorf("request failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
|
||||||
|
|
||||||
if !resp.IsSuccessState() {
|
|
||||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.RedirectURI == "" {
|
|
||||||
return "", fmt.Errorf("no redirect_uri in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse redirect_uri to extract code and state
|
|
||||||
parsedURL, err := url.Parse(result.RedirectURI)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
queryParams := parsedURL.Query()
|
|
||||||
authCode := queryParams.Get("code")
|
|
||||||
responseState := queryParams.Get("state")
|
|
||||||
|
|
||||||
if authCode == "" {
|
|
||||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combine code with state if present (as CRS does)
|
|
||||||
fullCode := authCode
|
|
||||||
if responseState != "" {
|
|
||||||
fullCode = authCode + "#" + responseState
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
|
|
||||||
return fullCode, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// exchangeCodeForToken exchanges authorization code for tokens
|
// exchangeCodeForToken exchanges authorization code for tokens
|
||||||
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
|
func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*TokenInfo, error) {
|
||||||
client := s.createReqClient(proxyURL)
|
tokenResp, err := s.oauthClient.ExchangeCodeForToken(ctx, code, codeVerifier, state, proxyURL)
|
||||||
|
|
||||||
// Parse code#state format if present
|
|
||||||
authCode := code
|
|
||||||
codeState := ""
|
|
||||||
if parts := strings.Split(code, "#"); len(parts) > 1 {
|
|
||||||
authCode = parts[0]
|
|
||||||
codeState = parts[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build JSON body as CRS does (not form data!)
|
|
||||||
reqBody := map[string]interface{}{
|
|
||||||
"code": authCode,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"client_id": oauth.ClientID,
|
|
||||||
"redirect_uri": oauth.RedirectURI,
|
|
||||||
"code_verifier": codeVerifier,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add state if present
|
|
||||||
if codeState != "" {
|
|
||||||
reqBody["state"] = codeState
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBodyJSON, _ := json.Marshal(reqBody)
|
|
||||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
|
|
||||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
|
||||||
|
|
||||||
var tokenResp oauth.TokenResponse
|
|
||||||
|
|
||||||
resp, err := client.R().
|
|
||||||
SetContext(ctx).
|
|
||||||
SetHeader("Content-Type", "application/json").
|
|
||||||
SetBody(reqBody).
|
|
||||||
SetSuccessResult(&tokenResp).
|
|
||||||
Post(oauth.TokenURL)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
return nil, err
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, resp.String())
|
|
||||||
|
|
||||||
if !resp.IsSuccessState() {
|
|
||||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
|
||||||
|
|
||||||
tokenInfo := &TokenInfo{
|
tokenInfo := &TokenInfo{
|
||||||
AccessToken: tokenResp.AccessToken,
|
AccessToken: tokenResp.AccessToken,
|
||||||
TokenType: tokenResp.TokenType,
|
TokenType: tokenResp.TokenType,
|
||||||
@@ -390,7 +238,6 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
|||||||
Scope: tokenResp.Scope,
|
Scope: tokenResp.Scope,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract org_uuid and account_uuid from response
|
|
||||||
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
|
if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" {
|
||||||
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
tokenInfo.OrgUUID = tokenResp.Organization.UUID
|
||||||
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID)
|
||||||
@@ -405,27 +252,9 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif
|
|||||||
|
|
||||||
// RefreshToken refreshes an OAuth token
|
// RefreshToken refreshes an OAuth token
|
||||||
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
|
func (s *OAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*TokenInfo, error) {
|
||||||
client := s.createReqClient(proxyURL)
|
tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL)
|
||||||
|
|
||||||
formData := url.Values{}
|
|
||||||
formData.Set("grant_type", "refresh_token")
|
|
||||||
formData.Set("refresh_token", refreshToken)
|
|
||||||
formData.Set("client_id", oauth.ClientID)
|
|
||||||
|
|
||||||
var tokenResp oauth.TokenResponse
|
|
||||||
|
|
||||||
resp, err := client.R().
|
|
||||||
SetContext(ctx).
|
|
||||||
SetFormDataFromValues(formData).
|
|
||||||
SetSuccessResult(&tokenResp).
|
|
||||||
Post(oauth.TokenURL)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("request failed: %w", err)
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
if !resp.IsSuccessState() {
|
|
||||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &TokenInfo{
|
return &TokenInfo{
|
||||||
@@ -455,17 +284,3 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
|||||||
|
|
||||||
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
return s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createReqClient creates a req client with Chrome impersonation and optional proxy
|
|
||||||
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
|
|
||||||
client := req.C().
|
|
||||||
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
|
|
||||||
SetTimeout(60 * time.Second)
|
|
||||||
|
|
||||||
// Set proxy if specified
|
|
||||||
if proxyURL != "" {
|
|
||||||
client.SetProxyURL(proxyURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
return client
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -20,13 +19,19 @@ import (
|
|||||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||||
type LiteLLMModelPricing struct {
|
type LiteLLMModelPricing struct {
|
||||||
InputCostPerToken float64 `json:"input_cost_per_token"`
|
InputCostPerToken float64 `json:"input_cost_per_token"`
|
||||||
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
OutputCostPerToken float64 `json:"output_cost_per_token"`
|
||||||
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
CacheCreationInputTokenCost float64 `json:"cache_creation_input_token_cost"`
|
||||||
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost"`
|
||||||
LiteLLMProvider string `json:"litellm_provider"`
|
LiteLLMProvider string `json:"litellm_provider"`
|
||||||
Mode string `json:"mode"`
|
Mode string `json:"mode"`
|
||||||
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
SupportsPromptCaching bool `json:"supports_prompt_caching"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PricingRemoteClient 远程价格数据获取接口
|
||||||
|
type PricingRemoteClient interface {
|
||||||
|
FetchPricingJSON(ctx context.Context, url string) ([]byte, error)
|
||||||
|
FetchHashText(ctx context.Context, url string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LiteLLMRawEntry 用于解析原始JSON数据
|
// LiteLLMRawEntry 用于解析原始JSON数据
|
||||||
@@ -42,11 +47,12 @@ type LiteLLMRawEntry struct {
|
|||||||
|
|
||||||
// PricingService 动态价格服务
|
// PricingService 动态价格服务
|
||||||
type PricingService struct {
|
type PricingService struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
mu sync.RWMutex
|
remoteClient PricingRemoteClient
|
||||||
pricingData map[string]*LiteLLMModelPricing
|
mu sync.RWMutex
|
||||||
lastUpdated time.Time
|
pricingData map[string]*LiteLLMModelPricing
|
||||||
localHash string
|
lastUpdated time.Time
|
||||||
|
localHash string
|
||||||
|
|
||||||
// 停止信号
|
// 停止信号
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
@@ -54,11 +60,12 @@ type PricingService struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPricingService 创建价格服务
|
// NewPricingService 创建价格服务
|
||||||
func NewPricingService(cfg *config.Config) *PricingService {
|
func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *PricingService {
|
||||||
s := &PricingService{
|
s := &PricingService{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
pricingData: make(map[string]*LiteLLMModelPricing),
|
remoteClient: remoteClient,
|
||||||
stopCh: make(chan struct{}),
|
pricingData: make(map[string]*LiteLLMModelPricing),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
@@ -199,21 +206,13 @@ func (s *PricingService) syncWithRemote() error {
|
|||||||
func (s *PricingService) downloadPricingData() error {
|
func (s *PricingService) downloadPricingData() error {
|
||||||
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
log.Printf("[Pricing] Downloading from %s", s.cfg.Pricing.RemoteURL)
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
resp, err := client.Get(s.cfg.Pricing.RemoteURL)
|
defer cancel()
|
||||||
|
|
||||||
|
body, err := s.remoteClient.FetchPricingJSON(ctx, s.cfg.Pricing.RemoteURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("download failed: %w", err)
|
return fmt.Errorf("download failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("download failed: HTTP %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("read response failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析JSON数据(使用灵活的解析方式)
|
// 解析JSON数据(使用灵活的解析方式)
|
||||||
data, err := s.parsePricingData(body)
|
data, err := s.parsePricingData(body)
|
||||||
@@ -367,29 +366,10 @@ func (s *PricingService) useFallbackPricing() error {
|
|||||||
|
|
||||||
// fetchRemoteHash 从远程获取哈希值
|
// fetchRemoteHash 从远程获取哈希值
|
||||||
func (s *PricingService) fetchRemoteHash() (string, error) {
|
func (s *PricingService) fetchRemoteHash() (string, error) {
|
||||||
client := &http.Client{Timeout: 10 * time.Second}
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
resp, err := client.Get(s.cfg.Pricing.HashURL)
|
defer cancel()
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
return s.remoteClient.FetchHashText(ctx, s.cfg.Pricing.HashURL)
|
||||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 哈希文件格式:hash filename 或者纯 hash
|
|
||||||
hash := strings.TrimSpace(string(body))
|
|
||||||
parts := strings.Fields(hash)
|
|
||||||
if len(parts) > 0 {
|
|
||||||
return parts[0], nil
|
|
||||||
}
|
|
||||||
return hash, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// computeFileHash 计算文件哈希
|
// computeFileHash 计算文件哈希
|
||||||
@@ -466,14 +446,14 @@ func (s *PricingService) extractBaseName(model string) string {
|
|||||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||||
// Claude模型系列匹配规则
|
// Claude模型系列匹配规则
|
||||||
familyPatterns := map[string][]string{
|
familyPatterns := map[string][]string{
|
||||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||||
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
"sonnet-4": {"claude-sonnet-4", "claude-3-5-sonnet"},
|
||||||
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
"sonnet-3.5": {"claude-3-5-sonnet", "claude-3.5-sonnet"},
|
||||||
"sonnet-3": {"claude-3-sonnet"},
|
"sonnet-3": {"claude-3-sonnet"},
|
||||||
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
"haiku-3.5": {"claude-3-5-haiku", "claude-3.5-haiku"},
|
||||||
"haiku-3": {"claude-3-haiku"},
|
"haiku-3": {"claude-3-haiku"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确定模型属于哪个系列
|
// 确定模型属于哪个系列
|
||||||
|
|||||||
@@ -26,4 +26,5 @@ type Services struct {
|
|||||||
Subscription *SubscriptionService
|
Subscription *SubscriptionService
|
||||||
Concurrency *ConcurrencyService
|
Concurrency *ConcurrencyService
|
||||||
Identity *IdentityService
|
Identity *IdentityService
|
||||||
|
Update *UpdateService
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,9 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -19,10 +14,15 @@ var (
|
|||||||
|
|
||||||
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||||
|
|
||||||
|
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||||
|
type TurnstileVerifier interface {
|
||||||
|
VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*TurnstileVerifyResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
// TurnstileService Turnstile 验证服务
|
// TurnstileService Turnstile 验证服务
|
||||||
type TurnstileService struct {
|
type TurnstileService struct {
|
||||||
settingService *SettingService
|
settingService *SettingService
|
||||||
httpClient *http.Client
|
verifier TurnstileVerifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
|
// TurnstileVerifyResponse Cloudflare Turnstile 验证响应
|
||||||
@@ -36,12 +36,10 @@ type TurnstileVerifyResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewTurnstileService 创建 Turnstile 服务实例
|
// NewTurnstileService 创建 Turnstile 服务实例
|
||||||
func NewTurnstileService(settingService *SettingService) *TurnstileService {
|
func NewTurnstileService(settingService *SettingService, verifier TurnstileVerifier) *TurnstileService {
|
||||||
return &TurnstileService{
|
return &TurnstileService{
|
||||||
settingService: settingService,
|
settingService: settingService,
|
||||||
httpClient: &http.Client{
|
verifier: verifier,
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,35 +64,12 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
|
|||||||
return ErrTurnstileVerificationFailed
|
return ErrTurnstileVerificationFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建请求
|
|
||||||
formData := url.Values{}
|
|
||||||
formData.Set("secret", secretKey)
|
|
||||||
formData.Set("response", token)
|
|
||||||
if remoteIP != "" {
|
|
||||||
formData.Set("remoteip", remoteIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
// 发送请求
|
|
||||||
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
|
log.Printf("[Turnstile] Verifying token for IP: %s", remoteIP)
|
||||||
resp, err := s.httpClient.Do(req)
|
result, err := s.verifier.VerifyToken(ctx, secretKey, token, remoteIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[Turnstile] Request failed: %v", err)
|
log.Printf("[Turnstile] Request failed: %v", err)
|
||||||
return fmt.Errorf("send request: %w", err)
|
return fmt.Errorf("send request: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// 解析响应
|
|
||||||
var result TurnstileVerifyResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
log.Printf("[Turnstile] Failed to decode response: %v", err)
|
|
||||||
return fmt.Errorf("decode response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !result.Success {
|
if !result.Success {
|
||||||
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
log.Printf("[Turnstile] Verification failed, error codes: %v", result.ErrorCodes)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -34,17 +33,26 @@ const (
|
|||||||
maxDownloadSize = 500 * 1024 * 1024
|
maxDownloadSize = 500 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GitHubReleaseClient 获取 GitHub release 信息的接口
|
||||||
|
type GitHubReleaseClient interface {
|
||||||
|
FetchLatestRelease(ctx context.Context, repo string) (*GitHubRelease, error)
|
||||||
|
DownloadFile(ctx context.Context, url, dest string, maxSize int64) error
|
||||||
|
FetchChecksumFile(ctx context.Context, url string) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateService handles software updates
|
// UpdateService handles software updates
|
||||||
type UpdateService struct {
|
type UpdateService struct {
|
||||||
cache ports.UpdateCache
|
cache ports.UpdateCache
|
||||||
|
githubClient GitHubReleaseClient
|
||||||
currentVersion string
|
currentVersion string
|
||||||
buildType string // "source" for manual builds, "release" for CI builds
|
buildType string // "source" for manual builds, "release" for CI builds
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUpdateService creates a new UpdateService
|
// NewUpdateService creates a new UpdateService
|
||||||
func NewUpdateService(cache ports.UpdateCache, version, buildType string) *UpdateService {
|
func NewUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, version, buildType string) *UpdateService {
|
||||||
return &UpdateService{
|
return &UpdateService{
|
||||||
cache: cache,
|
cache: cache,
|
||||||
|
githubClient: githubClient,
|
||||||
currentVersion: version,
|
currentVersion: version,
|
||||||
buildType: buildType,
|
buildType: buildType,
|
||||||
}
|
}
|
||||||
@@ -260,42 +268,11 @@ func (s *UpdateService) Rollback() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
|
func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, error) {
|
||||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", githubRepo)
|
release, err := s.githubClient.FetchLatestRelease(ctx, githubRepo)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
|
||||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusNotFound {
|
|
||||||
return &UpdateInfo{
|
|
||||||
CurrentVersion: s.currentVersion,
|
|
||||||
LatestVersion: s.currentVersion,
|
|
||||||
HasUpdate: false,
|
|
||||||
Warning: "No releases found",
|
|
||||||
BuildType: s.buildType,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
var release GitHubRelease
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
latestVersion := strings.TrimPrefix(release.TagName, "v")
|
latestVersion := strings.TrimPrefix(release.TagName, "v")
|
||||||
|
|
||||||
@@ -325,47 +302,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
|
func (s *UpdateService) downloadFile(ctx context.Context, downloadURL, dest string) error {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", downloadURL, nil)
|
return s.githubClient.DownloadFile(ctx, downloadURL, dest, maxDownloadSize)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 10 * time.Minute}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SECURITY: Check Content-Length if available
|
|
||||||
if resp.ContentLength > maxDownloadSize {
|
|
||||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxDownloadSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := os.Create(dest)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer out.Close()
|
|
||||||
|
|
||||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
|
||||||
limited := io.LimitReader(resp.Body, maxDownloadSize+1)
|
|
||||||
written, err := io.Copy(out, limited)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we hit the limit (downloaded more than maxDownloadSize)
|
|
||||||
if written > maxDownloadSize {
|
|
||||||
os.Remove(dest) // Clean up partial file
|
|
||||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxDownloadSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UpdateService) getArchiveName() string {
|
func (s *UpdateService) getArchiveName() string {
|
||||||
@@ -402,20 +339,9 @@ func validateDownloadURL(rawURL string) error {
|
|||||||
|
|
||||||
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
|
func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumURL string) error {
|
||||||
// Download checksums file
|
// Download checksums file
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", checksumURL, nil)
|
checksumData, err := s.githubClient.FetchChecksumFile(ctx, checksumURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to download checksums: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("failed to download checksums: %d", resp.StatusCode)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate file hash
|
// Calculate file hash
|
||||||
@@ -433,7 +359,7 @@ func (s *UpdateService) verifyChecksum(ctx context.Context, filePath, checksumUR
|
|||||||
|
|
||||||
// Find expected hash in checksums file
|
// Find expected hash in checksums file
|
||||||
fileName := filepath.Base(filePath)
|
fileName := filepath.Base(filePath)
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(strings.NewReader(string(checksumData)))
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
parts := strings.Fields(line)
|
parts := strings.Fields(line)
|
||||||
|
|||||||
@@ -2,13 +2,20 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"sub2api/internal/config"
|
"sub2api/internal/config"
|
||||||
|
"sub2api/internal/service/ports"
|
||||||
|
|
||||||
"github.com/google/wire"
|
"github.com/google/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// BuildInfo contains build information
|
||||||
|
type BuildInfo struct {
|
||||||
|
Version string
|
||||||
|
BuildType string
|
||||||
|
}
|
||||||
|
|
||||||
// ProvidePricingService creates and initializes PricingService
|
// ProvidePricingService creates and initializes PricingService
|
||||||
func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
|
func ProvidePricingService(cfg *config.Config, remoteClient PricingRemoteClient) (*PricingService, error) {
|
||||||
svc := NewPricingService(cfg)
|
svc := NewPricingService(cfg, remoteClient)
|
||||||
if err := svc.Initialize(); err != nil {
|
if err := svc.Initialize(); err != nil {
|
||||||
// 价格服务初始化失败不应阻止启动,使用回退价格
|
// 价格服务初始化失败不应阻止启动,使用回退价格
|
||||||
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
println("[Service] Warning: Pricing service initialization failed:", err.Error())
|
||||||
@@ -16,6 +23,11 @@ func ProvidePricingService(cfg *config.Config) (*PricingService, error) {
|
|||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProvideUpdateService creates UpdateService with BuildInfo
|
||||||
|
func ProvideUpdateService(cache ports.UpdateCache, githubClient GitHubReleaseClient, buildInfo BuildInfo) *UpdateService {
|
||||||
|
return NewUpdateService(cache, githubClient, buildInfo.Version, buildInfo.BuildType)
|
||||||
|
}
|
||||||
|
|
||||||
// ProvideEmailQueueService creates EmailQueueService with default worker count
|
// ProvideEmailQueueService creates EmailQueueService with default worker count
|
||||||
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService {
|
||||||
return NewEmailQueueService(emailService, 3)
|
return NewEmailQueueService(emailService, 3)
|
||||||
@@ -48,6 +60,7 @@ var ProviderSet = wire.NewSet(
|
|||||||
NewSubscriptionService,
|
NewSubscriptionService,
|
||||||
NewConcurrencyService,
|
NewConcurrencyService,
|
||||||
NewIdentityService,
|
NewIdentityService,
|
||||||
|
ProvideUpdateService,
|
||||||
|
|
||||||
// Provide the Services container struct
|
// Provide the Services container struct
|
||||||
wire.Struct(new(Services), "*"),
|
wire.Struct(new(Services), "*"),
|
||||||
|
|||||||
Reference in New Issue
Block a user