diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d20ed0c8..1dc2278e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -4,6 +4,22 @@ on: push: tags: - 'v*' + workflow_dispatch: + inputs: + tag: + description: 'Tag to release (e.g., v1.0.0)' + required: true + type: string + simple_release: + description: 'Simple release: only x86_64 GHCR image, skip other artifacts' + required: false + type: boolean + default: false + +# 环境变量:合并 workflow_dispatch 输入和 repository variable +# tag push 触发时读取 vars.SIMPLE_RELEASE,workflow_dispatch 时使用输入参数 +env: + SIMPLE_RELEASE: ${{ github.event.inputs.simple_release == 'true' || vars.SIMPLE_RELEASE == 'true' }} permissions: contents: write @@ -19,7 +35,12 @@ jobs: - name: Update VERSION file run: | - VERSION=${GITHUB_REF#refs/tags/v} + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + VERSION=${{ github.event.inputs.tag }} + VERSION=${VERSION#v} + else + VERSION=${GITHUB_REF#refs/tags/v} + fi echo "$VERSION" > backend/cmd/server/VERSION echo "Updated VERSION file to: $VERSION" @@ -66,6 +87,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + ref: ${{ github.event.inputs.tag || github.ref }} - name: Download VERSION artifact uses: actions/download-artifact@v4 @@ -93,7 +115,10 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Login to DockerHub + if: ${{ env.DOCKERHUB_USERNAME != '' }} uses: docker/login-action@v3 + env: + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} @@ -113,7 +138,11 @@ jobs: - name: Get tag message id: tag_message run: | - TAG_NAME=${GITHUB_REF#refs/tags/} + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + TAG_NAME=${{ github.event.inputs.tag }} + else + TAG_NAME=${GITHUB_REF#refs/tags/} + fi echo "Processing tag: $TAG_NAME" # 获取完整的 tag message(跳过第一行标题) @@ -137,18 +166,21 @@ jobs: uses: goreleaser/goreleaser-action@v6 with: version: '~> v2' - args: release --clean --skip=validate + args: release --clean --skip=validate ${{ env.SIMPLE_RELEASE == 'true' && '--config=.goreleaser.simple.yaml' || '' }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} TAG_MESSAGE: ${{ steps.tag_message.outputs.message }} GITHUB_REPO_OWNER: ${{ github.repository_owner }} GITHUB_REPO_OWNER_LOWER: ${{ steps.lowercase.outputs.owner }} GITHUB_REPO_NAME: ${{ github.event.repository.name }} - DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME || 'skip' }} # Update DockerHub description - name: Update DockerHub description + if: ${{ env.SIMPLE_RELEASE != 'true' && env.DOCKERHUB_USERNAME != '' }} uses: peter-evans/dockerhub-description@v4 + env: + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} @@ -158,9 +190,11 @@ jobs: # Send Telegram notification - name: Send Telegram Notification + if: ${{ env.SIMPLE_RELEASE != 'true' }} env: TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }} TELEGRAM_CHAT_ID: ${{ secrets.TELEGRAM_CHAT_ID }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} continue-on-error: true run: | # 检查必要的环境变量 @@ -169,10 +203,13 @@ jobs: exit 0 fi - TAG_NAME=${GITHUB_REF#refs/tags/} + if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then + TAG_NAME=${{ github.event.inputs.tag }} + else + TAG_NAME=${GITHUB_REF#refs/tags/} + fi VERSION=${TAG_NAME#v} REPO="${{ github.repository }}" - DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api" GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase # 获取 tag message 内容 @@ -194,14 +231,20 @@ jobs: MESSAGE+="🐳 *Docker 部署:*"$'\n' MESSAGE+="\`\`\`bash"$'\n' - MESSAGE+="# Docker Hub"$'\n' - MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n' - MESSAGE+="# GitHub Container Registry"$'\n' + # 根据是否配置 DockerHub 动态生成 + if [ -n "$DOCKERHUB_USERNAME" ]; then + DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api" + MESSAGE+="# Docker Hub"$'\n' + MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n' + MESSAGE+="# GitHub Container Registry"$'\n' + fi MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n' MESSAGE+="\`\`\`"$'\n'$'\n' MESSAGE+="🔗 *相关链接:*"$'\n' MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n' - MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n' + if [ -n "$DOCKERHUB_USERNAME" ]; then + MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n' + fi MESSAGE+="• [GitHub Packages](https://github.com/${REPO}/pkgs/container/sub2api)"$'\n'$'\n' MESSAGE+="#Sub2API #Release #${TAG_NAME//./_}" diff --git a/.goreleaser.simple.yaml b/.goreleaser.simple.yaml new file mode 100644 index 00000000..2155ed9d --- /dev/null +++ b/.goreleaser.simple.yaml @@ -0,0 +1,86 @@ +# 简化版 GoReleaser 配置 - 仅发布 x86_64 GHCR 镜像 +version: 2 + +project_name: sub2api + +before: + hooks: + - go mod tidy -C backend + +builds: + - id: sub2api + dir: backend + main: ./cmd/server + binary: sub2api + flags: + - -tags=embed + env: + - CGO_ENABLED=0 + goos: + - linux + goarch: + - amd64 + ldflags: + - -s -w + - -X main.Commit={{.Commit}} + - -X main.Date={{.Date}} + - -X main.BuildType=release + +# 跳过 archives +archives: [] + +# 跳过 checksum +checksum: + disable: true + +changelog: + disable: true + +# 仅 GHCR x86_64 镜像 +dockers: + - id: ghcr-amd64 + goos: linux + goarch: amd64 + image_templates: + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}" + - "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest" + dockerfile: Dockerfile.goreleaser + use: buildx + build_flag_templates: + - "--platform=linux/amd64" + - "--label=org.opencontainers.image.version={{ .Version }}" + - "--label=org.opencontainers.image.revision={{ .Commit }}" + - "--label=org.opencontainers.image.source=https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}" + +# 跳过 manifests(单架构不需要) +docker_manifests: [] + +release: + github: + owner: "{{ .Env.GITHUB_REPO_OWNER }}" + name: "{{ .Env.GITHUB_REPO_NAME }}" + draft: false + prerelease: auto + name_template: "Sub2API {{.Version}} (Simple)" + # 跳过上传二进制包 + skip_upload: true + header: | + > AI API Gateway Platform - 将 AI 订阅配额分发和管理 + > ⚡ Simple Release: 仅包含 x86_64 GHCR 镜像 + + {{ .Env.TAG_MESSAGE }} + + footer: | + --- + + ## 📥 Installation + + **Docker (x86_64 only):** + ```bash + docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }} + ``` + + ## 📚 Documentation + + - [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}) diff --git a/.goreleaser.yaml b/.goreleaser.yaml index c72f7422..da2f9aa5 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -54,9 +54,11 @@ changelog: # Docker images dockers: + # DockerHub images (skipped if DOCKERHUB_USERNAME is 'skip') - id: amd64 goos: linux goarch: amd64 + skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}' image_templates: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" dockerfile: Dockerfile.goreleaser @@ -69,6 +71,7 @@ dockers: - id: arm64 goos: linux goarch: arm64 + skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}' image_templates: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" dockerfile: Dockerfile.goreleaser @@ -107,22 +110,27 @@ dockers: # Docker manifests for multi-arch support docker_manifests: + # DockerHub manifests (skipped if DOCKERHUB_USERNAME is 'skip') - name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}" + skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}' image_templates: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" - name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest" + skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}' image_templates: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" - name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}" + skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}' image_templates: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" - name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}" + skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}' image_templates: - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64" - "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64" @@ -169,9 +177,11 @@ release: **Docker:** ```bash + {{ if ne .Env.DOCKERHUB_USERNAME "skip" -}} # Docker Hub docker pull {{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }} + {{ end -}} # GitHub Container Registry docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }} ``` diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index eb763bbe..2e33003b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, configConfig) - accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) + accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.NewConcurrencyService(concurrencyCache) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 0480b312..99557f9a 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -26,7 +26,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler { // CreateProxyRequest represents create proxy request type CreateProxyRequest struct { Name string `json:"name" binding:"required"` - Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` + Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"` Host string `json:"host" binding:"required"` Port int `json:"port" binding:"required,min=1,max=65535"` Username string `json:"username"` @@ -36,7 +36,7 @@ type CreateProxyRequest struct { // UpdateProxyRequest represents update proxy request type UpdateProxyRequest struct { Name string `json:"name"` - Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"` + Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"` Host string `json:"host"` Port int `json:"port" binding:"omitempty,min=1,max=65535"` Username string `json:"username"` @@ -255,7 +255,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) { // BatchCreateProxyItem represents a single proxy in batch create request type BatchCreateProxyItem struct { - Protocol string `json:"protocol" binding:"required,oneof=http https socks5"` + Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"` Host string `json:"host" binding:"required"` Port int `json:"port" binding:"required,min=1,max=65535"` Username string `json:"username"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 0be81ae2..66183ced 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -389,7 +389,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { continue } // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Forward request failed: %v", err) + log.Printf("Account %d: Forward request failed: %v", account.ID, err) return } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 0a7602c6..c8557901 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -237,7 +237,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { continue } // Error response already handled in Forward, just log - log.Printf("Forward request failed: %v", err) + log.Printf("Account %d: Forward request failed: %v", account.ID, err) return } diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 3cd5a592..7bf5cff4 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -11,23 +11,21 @@ // 新实现使用统一的客户端池: // 1. 相同配置复用同一 http.Client 实例 // 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销 -// 3. 支持 HTTP/HTTPS/SOCKS5 代理 -// 4. 支持严格代理模式(代理失败则返回错误) +// 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理 +// 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险) package httpclient import ( - "context" "crypto/tls" "fmt" - "net" "net/http" "net/url" "strings" "sync" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" - "golang.org/x/net/proxy" ) // Transport 连接池默认配置 @@ -39,7 +37,7 @@ const ( // Options 定义共享 HTTP 客户端的构建参数 type Options struct { - ProxyURL string // 代理 URL(支持 http/https/socks5) + ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h) Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证 @@ -58,6 +56,7 @@ var sharedClients sync.Map // GetClient 返回共享的 HTTP 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 Transport +// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 func GetClient(opts Options) (*http.Client, error) { key := buildClientKey(opts) if cached, ok := sharedClients.Load(key); ok { @@ -68,12 +67,7 @@ func GetClient(opts Options) (*http.Client, error) { client, err := buildClient(opts) if err != nil { - if opts.ProxyStrict { - return nil, err - } - fallback := opts - fallback.ProxyURL = "" - client, _ = buildClient(fallback) + return nil, err } actual, _ := sharedClients.LoadOrStore(key, client) @@ -132,19 +126,8 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, err } - switch strings.ToLower(parsed.Scheme) { - case "http", "https": - transport.Proxy = http.ProxyURL(parsed) - case "socks5", "socks5h": - dialer, err := proxy.FromURL(parsed, proxy.Direct) - if err != nil { - return nil, err - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - default: - return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme) + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, err } return transport, nil diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go new file mode 100644 index 00000000..91b224a2 --- /dev/null +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -0,0 +1,62 @@ +// Package proxyutil 提供统一的代理配置功能 +// +// 支持的代理协议: +// - HTTP/HTTPS: 通过 Transport.Proxy 设置 +// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) +package proxyutil + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +// ConfigureTransportProxy 根据代理 URL 配置 Transport +// +// 支持的协议: +// - http/https: 设置 transport.Proxy +// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) +// +// 参数: +// - transport: 需要配置的 http.Transport +// - proxyURL: 代理地址,nil 表示直连 +// +// 返回: +// - error: 代理配置错误(协议不支持或 dialer 创建失败) +func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error { + if proxyURL == nil { + return nil + } + + scheme := strings.ToLower(proxyURL.Scheme) + switch scheme { + case "http", "https": + transport.Proxy = http.ProxyURL(proxyURL) + return nil + + case "socks5", "socks5h": + dialer, err := proxy.FromURL(proxyURL, proxy.Direct) + if err != nil { + return fmt.Errorf("create socks5 dialer: %w", err) + } + // 优先使用支持 context 的 DialContext,以支持请求取消和超时 + if contextDialer, ok := dialer.(proxy.ContextDialer); ok { + transport.DialContext = contextDialer.DialContext + } else { + // 回退路径:如果 dialer 不支持 ContextDialer,则包装为简单的 DialContext + // 注意:此回退不支持请求取消和超时控制 + transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + } + return nil + + default: + return fmt.Errorf("unsupported proxy scheme: %s", scheme) + } +} diff --git a/backend/internal/pkg/proxyutil/dialer_test.go b/backend/internal/pkg/proxyutil/dialer_test.go new file mode 100644 index 00000000..f153cc9f --- /dev/null +++ b/backend/internal/pkg/proxyutil/dialer_test.go @@ -0,0 +1,204 @@ +package proxyutil + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigureTransportProxy_Nil(t *testing.T) { + transport := &http.Transport{} + err := ConfigureTransportProxy(transport, nil) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy") + assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext") +} + +func TestConfigureTransportProxy_HTTP(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("http://proxy.example.com:8080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy") + assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext") +} + +func TestConfigureTransportProxy_HTTPS(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy") + assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext") +} + +func TestConfigureTransportProxy_SOCKS5(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy") + assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext") +} + +func TestConfigureTransportProxy_SOCKS5H(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5h://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy") + assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext") +} + +func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) { + testCases := []struct { + scheme string + useProxy bool // true = uses Transport.Proxy, false = uses DialContext + }{ + {"HTTP://proxy.example.com:8080", true}, + {"Http://proxy.example.com:8080", true}, + {"HTTPS://proxy.example.com:8443", true}, + {"Https://proxy.example.com:8443", true}, + {"SOCKS5://socks.example.com:1080", false}, + {"Socks5://socks.example.com:1080", false}, + {"SOCKS5H://socks.example.com:1080", false}, + {"Socks5h://socks.example.com:1080", false}, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse(tc.scheme) + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + if tc.useProxy { + assert.NotNil(t, transport.Proxy) + assert.Nil(t, transport.DialContext) + } else { + assert.Nil(t, transport.Proxy) + assert.NotNil(t, transport.DialContext) + } + }) + } +} + +func TestConfigureTransportProxy_Unsupported(t *testing.T) { + testCases := []string{ + "ftp://ftp.example.com", + "file:///path/to/file", + "unknown://example.com", + } + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse(tc) + + err := ConfigureTransportProxy(transport, proxyURL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported proxy scheme") + }) + } +} + +func TestConfigureTransportProxy_WithAuth(t *testing.T) { + transport := &http.Transport{} + proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext") +} + +func TestConfigureTransportProxy_EmptyScheme(t *testing.T) { + transport := &http.Transport{} + // 空 scheme 的 URL + proxyURL := &url.URL{Host: "proxy.example.com:8080"} + + err := ConfigureTransportProxy(transport, proxyURL) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported proxy scheme") +} + +func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) { + // 验证代理配置不会覆盖 Transport 的其他配置 + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + } + proxyURL, _ := url.Parse("socks5://socks.example.com:1080") + + err := ConfigureTransportProxy(transport, proxyURL) + + require.NoError(t, err) + assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved") + assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved") + assert.NotNil(t, transport.DialContext, "DialContext should be set") +} + +func TestConfigureTransportProxy_IPv6(t *testing.T) { + testCases := []struct { + name string + proxyURL string + }{ + {"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"}, + {"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"}, + {"HTTP with IPv6", "http://[::1]:8080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, err := url.Parse(tc.proxyURL) + require.NoError(t, err, "URL should be parseable") + + err = ConfigureTransportProxy(transport, proxyURL) + require.NoError(t, err) + }) + } +} + +func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) { + testCases := []struct { + name string + proxyURL string + }{ + // 密码包含 @ 符号(URL 编码为 %40) + {"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"}, + // 密码包含 : 符号(URL 编码为 %3A) + {"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"}, + // 密码包含 / 符号(URL 编码为 %2F) + {"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"}, + // 复杂密码 + {"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + transport := &http.Transport{} + proxyURL, err := url.Parse(tc.proxyURL) + require.NoError(t, err, "URL should be parseable") + + err = ConfigureTransportProxy(transport, proxyURL) + require.NoError(t, err) + assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext") + }) + } +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 051741aa..8595e783 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -234,11 +234,17 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func createReqClient(proxyURL string) *req.Client { - return getSharedReqClient(reqClientOptions{ - ProxyURL: proxyURL, - Timeout: 60 * time.Second, - Impersonate: true, - }) + // 禁用 CookieJar,确保每次授权都是干净的会话 + client := req.C(). + SetTimeout(60 * time.Second). + ImpersonateChrome(). + SetCookieJar(nil) // 禁用 CookieJar + + if strings.TrimSpace(proxyURL) != "" { + client.SetProxyURL(strings.TrimSpace(proxyURL)) + } + + return client } func prefix(s string, n int) string { diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index 3c84ab1d..21723d4a 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -261,7 +262,12 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 缓存未命中或需要重建,创建新客户端 settings := s.resolvePoolSettings(isolation, accountConcurrency) - client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)} + transport, err := buildUpstreamTransport(settings, parsedProxy) + if err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("build transport: %w", err) + } + client := &http.Client{Transport: transport} if s.shouldValidateResolvedIP() { client.CheckRedirect = s.redirectChecker } @@ -587,6 +593,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // // 返回: // - *http.Transport: 配置好的 Transport 实例 +// - error: 代理配置错误 // // Transport 参数说明: // - MaxIdleConns: 所有主机的最大空闲连接总数 @@ -594,7 +601,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) // - IdleConnTimeout: 空闲连接超时(超时后关闭) // - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) -func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport { +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) { transport := &http.Transport{ MaxIdleConns: settings.maxIdleConns, MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, @@ -602,10 +609,10 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Tran IdleConnTimeout: settings.idleConnTimeout, ResponseHeaderTimeout: settings.responseHeaderTimeout, } - if proxyURL != nil { - transport.Proxy = http.ProxyURL(proxyURL) + if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { + return nil, err } - return transport + return transport, nil } // trackedBody 带跟踪功能的响应体包装器 diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 3219c6da..1e7430a3 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -45,8 +45,12 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { settings := defaultPoolSettings(cfg) for i := 0; i < b.N; i++ { // 每次迭代都创建新客户端,包含 Transport 分配 + transport, err := buildUpstreamTransport(settings, parsedProxy) + if err != nil { + b.Fatalf("创建 Transport 失败: %v", err) + } httpClientSink = &http.Client{ - Transport: buildUpstreamTransport(settings, parsedProxy), + Transport: transport, } } }) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 5797e497..820b532f 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -14,7 +14,6 @@ import ( "net/http" "regexp" "strings" - "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -31,7 +30,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" - testOpenAIAPIURL = "https://api.openai.com/v1/responses" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" ) @@ -47,8 +45,6 @@ type TestEvent struct { // AccountTestService handles account testing operations type AccountTestService struct { accountRepo AccountRepository - oauthService *OAuthService - openaiOAuthService *OpenAIOAuthService geminiTokenProvider *GeminiTokenProvider antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream @@ -58,8 +54,6 @@ type AccountTestService struct { // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, - oauthService *OAuthService, - openaiOAuthService *OpenAIOAuthService, geminiTokenProvider *GeminiTokenProvider, antigravityGatewayService *AntigravityGatewayService, httpUpstream HTTPUpstream, @@ -67,8 +61,6 @@ func NewAccountTestService( ) *AccountTestService { return &AccountTestService{ accountRepo: accountRepo, - oauthService: oauthService, - openaiOAuthService: openaiOAuthService, geminiTokenProvider: geminiTokenProvider, antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, @@ -204,22 +196,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account if authToken == "" { return s.sendErrorAndEnd(c, "No access token available") } - - // Check if token needs refresh - needRefresh := false - if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil { - if time.Now().Add(5 * time.Minute).After(*expiresAt) { - needRefresh = true - } - } - - if needRefresh && s.oauthService != nil { - tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error())) - } - authToken = tokenInfo.AccessToken - } } else if account.Type == "apikey" { // API Key - use x-api-key header useBearer = false @@ -335,15 +311,6 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account return s.sendErrorAndEnd(c, "No access token available") } - // Check if token is expired and refresh if needed - if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil { - tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error())) - } - authToken = tokenInfo.AccessToken - } - // OAuth uses ChatGPT internal API apiURL = chatgptCodexAPIURL chatgptAccountID = account.GetChatGPTAccountID() diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index c4220c0c..439e9508 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -76,7 +76,7 @@ type antigravityUsageCache struct { } const ( - apiCacheTTL = 10 * time.Minute + apiCacheTTL = 3 * time.Minute windowStatsCacheTTL = 1 * time.Minute ) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 962b3684..707e728b 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -630,6 +630,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou Concurrency: input.Concurrency, Priority: input.Priority, Status: StatusActive, + Schedulable: true, } if err := s.accountRepo.Create(ctx, account); err != nil { return nil, err diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 8f97598f..b0452be6 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -22,11 +22,27 @@ import ( const ( antigravityStickySessionTTL = time.Hour - antigravityMaxRetries = 5 + antigravityMaxRetries = 3 antigravityRetryBaseDelay = 1 * time.Second antigravityRetryMaxDelay = 16 * time.Second ) +// getSessionID 从 gin.Context 获取 session_id(用于日志追踪) +func getSessionID(c *gin.Context) string { + if c == nil { + return "" + } + return c.GetHeader("session_id") +} + +// logPrefix 生成统一的日志前缀 +func logPrefix(sessionID, accountName string) string { + if sessionID != "" { + return fmt.Sprintf("[antigravity-Forward] session=%s account=%s", sessionID, accountName) + } + return fmt.Sprintf("[antigravity-Forward] account=%s", accountName) +} + // Antigravity 直接支持的模型(精确匹配透传) var antigravitySupportedModels = map[string]bool{ "claude-opus-4-5-thinking": true, @@ -48,10 +64,11 @@ var antigravityPrefixMapping = []struct { target string }{ // 长前缀优先 - {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 - {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx - {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx - {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet + {"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image + {"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等 + {"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx + {"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx + {"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet {"claude-opus-4-5", "claude-opus-4-5-thinking"}, {"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet {"claude-sonnet-4", "claude-sonnet-4-5"}, @@ -315,6 +332,8 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt // Forward 转发 Claude 协议请求(Claude → Gemini 转换) func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) // 解析 Claude 请求 var claudeReq antigravity.ClaudeRequest @@ -369,10 +388,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { if attempt < antigravityMaxRetries { - log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) sleepAntigravityBackoff(attempt) continue } + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } @@ -381,13 +401,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, _ = resp.Body.Close() if attempt < antigravityMaxRetries { - log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries) + log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) sleepAntigravityBackoff(attempt) continue } // 所有重试都失败,标记限流状态 if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) } // 最后一次尝试也失败 resp = &http.Response{ @@ -405,7 +425,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 处理错误响应 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} @@ -424,6 +444,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if claudeReq.Stream { streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) if err != nil { + log.Printf("%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -448,6 +469,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // ForwardGemini 转发 Gemini 协议请求 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) { startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) if strings.TrimSpace(originalModel) == "" { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL") @@ -523,10 +546,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) if err != nil { if attempt < antigravityMaxRetries { - log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err) + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) sleepAntigravityBackoff(attempt) continue } + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } @@ -535,13 +559,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co _ = resp.Body.Close() if attempt < antigravityMaxRetries { - log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries) + log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) sleepAntigravityBackoff(attempt) continue } // 所有重试都失败,标记限流状态 if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) } resp = &http.Response{ StatusCode: resp.StatusCode, @@ -563,7 +587,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 处理错误响应 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverUpstreamError(resp.StatusCode) { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} @@ -585,6 +609,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if stream || upstreamAction == "streamGenerateContent" { streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) if err != nil { + log.Printf("%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -633,7 +658,7 @@ func sleepAntigravityBackoff(attempt int) { sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑 } -func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) { +func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) { // 429 使用 Gemini 格式解析(从 body 解析重置时间) if statusCode == 429 { resetAt := ParseGeminiRateLimitResetTime(body) @@ -644,17 +669,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, acc defaultDur = 5 * time.Minute } ra := time.Now().Add(defaultDur) + log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur) _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) return } - _ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0)) + resetTime := time.Unix(*resetAt, 0) + log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) + _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) return } // 其他错误码继续使用 rateLimitService if s.rateLimitService == nil { return } - s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) + shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) + if shouldDisable { + log.Printf("%s status=%d marked_error", prefix, statusCode) + } } type antigravityStreamResult struct { @@ -851,7 +882,7 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error { // 记录上游错误详情便于调试 - log.Printf("Antigravity upstream error %d: %s", upstreamStatus, string(body)) + log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, string(body)) var statusCode int var errType, errMsg string @@ -925,7 +956,7 @@ func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Cont // 转换 Gemini 响应为 Claude 格式 claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(body, originalModel) if err != nil { - log.Printf("Transform Gemini to Claude failed: %v, body: %s", err, string(body)) + log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(body)) return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go new file mode 100644 index 00000000..b056f8fa --- /dev/null +++ b/backend/internal/service/gateway_prompt_test.go @@ -0,0 +1,233 @@ +package service + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsClaudeCodeClient(t *testing.T) { + tests := []struct { + name string + userAgent string + metadataUserID string + want bool + }{ + { + name: "Claude Code client", + userAgent: "claude-cli/1.0.62 (darwin; arm64)", + metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + want: true, + }, + { + name: "Claude Code without version suffix", + userAgent: "claude-cli/2.0.0", + metadataUserID: "session_abc", + want: true, + }, + { + name: "Missing metadata user_id", + userAgent: "claude-cli/1.0.0", + metadataUserID: "", + want: false, + }, + { + name: "Different user agent", + userAgent: "curl/7.68.0", + metadataUserID: "user123", + want: false, + }, + { + name: "Empty user agent", + userAgent: "", + metadataUserID: "user123", + want: false, + }, + { + name: "Similar but not Claude CLI", + userAgent: "claude-api/1.0.0", + metadataUserID: "user123", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSystemIncludesClaudeCodePrompt(t *testing.T) { + tests := []struct { + name string + system any + want bool + }{ + { + name: "nil system", + system: nil, + want: false, + }, + { + name: "empty string", + system: "", + want: false, + }, + { + name: "string with Claude Code prompt", + system: claudeCodeSystemPrompt, + want: true, + }, + { + name: "string with different content", + system: "You are a helpful assistant.", + want: false, + }, + { + name: "empty array", + system: []any{}, + want: false, + }, + { + name: "array with Claude Code prompt", + system: []any{ + map[string]any{ + "type": "text", + "text": claudeCodeSystemPrompt, + }, + }, + want: true, + }, + { + name: "array with Claude Code prompt in second position", + system: []any{ + map[string]any{"type": "text", "text": "First prompt"}, + map[string]any{"type": "text", "text": claudeCodeSystemPrompt}, + }, + want: true, + }, + { + name: "array without Claude Code prompt", + system: []any{ + map[string]any{"type": "text", "text": "Custom prompt"}, + }, + want: false, + }, + { + name: "array with partial match (should not match)", + system: []any{ + map[string]any{"type": "text", "text": "You are Claude"}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := systemIncludesClaudeCodePrompt(tt.system) + require.Equal(t, tt.want, got) + }) + } +} + +func TestInjectClaudeCodePrompt(t *testing.T) { + tests := []struct { + name string + body string + system any + wantSystemLen int + wantFirstText string + wantSecondText string + }{ + { + name: "nil system", + body: `{"model":"claude-3"}`, + system: nil, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "empty string system", + body: `{"model":"claude-3"}`, + system: "", + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "string system", + body: `{"model":"claude-3"}`, + system: "Custom prompt", + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Custom prompt", + }, + { + name: "string system equals Claude Code prompt", + body: `{"model":"claude-3"}`, + system: claudeCodeSystemPrompt, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + { + name: "array system", + body: `{"model":"claude-3"}`, + system: []any{map[string]any{"type": "text", "text": "Custom"}}, + // Claude Code + Custom = 2 + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Custom", + }, + { + name: "array system with existing Claude Code prompt (should dedupe)", + body: `{"model":"claude-3"}`, + system: []any{ + map[string]any{"type": "text", "text": claudeCodeSystemPrompt}, + map[string]any{"type": "text", "text": "Other"}, + }, + // Claude Code at start + Other = 2 (deduped) + wantSystemLen: 2, + wantFirstText: claudeCodeSystemPrompt, + wantSecondText: "Other", + }, + { + name: "empty array", + body: `{"model":"claude-3"}`, + system: []any{}, + wantSystemLen: 1, + wantFirstText: claudeCodeSystemPrompt, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := injectClaudeCodePrompt([]byte(tt.body), tt.system) + + var parsed map[string]any + err := json.Unmarshal(result, &parsed) + require.NoError(t, err) + + system, ok := parsed["system"].([]any) + require.True(t, ok, "system should be an array") + require.Len(t, system, tt.wantSystemLen) + + first, ok := system[0].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantFirstText, first["text"]) + require.Equal(t, "text", first["type"]) + + // Check cache_control + cc, ok := first["cache_control"].(map[string]any) + require.True(t, ok) + require.Equal(t, "ephemeral", cc["type"]) + + if tt.wantSecondText != "" && len(system) > 1 { + second, ok := system[1].(map[string]any) + require.True(t, ok) + require.Equal(t, tt.wantSecondText, second["text"]) + } + }) + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 47c136df..4946e7bc 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -34,13 +34,15 @@ const ( claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 10 * 1024 * 1024 + claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." ) // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( - sseDataRe = regexp.MustCompile(`^data:\s*`) - sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) + sseDataRe = regexp.MustCompile(`^data:\s*`) + sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) + claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) ) // allowedHeaders 白名单headers(参考CRS项目) @@ -955,6 +957,76 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { } } +// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 +// 简化判断:User-Agent 匹配 + metadata.user_id 存在 +func isClaudeCodeClient(userAgent string, metadataUserID string) bool { + if metadataUserID == "" { + return false + } + return claudeCliUserAgentRe.MatchString(userAgent) +} + +// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 +// 支持 string 和 []any 两种格式 +func systemIncludesClaudeCodePrompt(system any) bool { + switch v := system.(type) { + case string: + return v == claudeCodeSystemPrompt + case []any: + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + return true + } + } + } + } + return false +} + +// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 +// 处理 null、字符串、数组三种格式 +func injectClaudeCodePrompt(body []byte, system any) []byte { + claudeCodeBlock := map[string]any{ + "type": "text", + "text": claudeCodeSystemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + } + + var newSystem []any + + switch v := system.(type) { + case nil: + newSystem = []any{claudeCodeBlock} + case string: + if v == "" || v == claudeCodeSystemPrompt { + newSystem = []any{claudeCodeBlock} + } else { + newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}} + } + case []any: + newSystem = make([]any, 0, len(v)+1) + newSystem = append(newSystem, claudeCodeBlock) + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt { + continue + } + } + newSystem = append(newSystem, item) + } + default: + newSystem = []any{claudeCodeBlock} + } + + result, err := sjson.SetBytes(body, "system", newSystem) + if err != nil { + log.Printf("Warning: failed to inject Claude Code prompt: %v", err) + return body + } + return result +} + // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() @@ -966,16 +1038,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A reqModel := parsed.Model reqStream := parsed.Stream - if !parsed.HasSystem { - body, _ = sjson.SetBytes(body, "system", []any{ - map[string]any{ - "type": "text", - "text": "You are Claude Code, Anthropic's official CLI for Claude.", - "cache_control": map[string]string{ - "type": "ephemeral", - }, - }, - }) + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) + // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + if account.IsOAuth() && + !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) && + !strings.Contains(strings.ToLower(reqModel), "haiku") && + !systemIncludesClaudeCodePrompt(parsed.System) { + body = injectClaudeCodePrompt(body, parsed.System) } // 应用模型映射(仅对apikey类型账号) diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 76ca61fd..3ed35f04 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "strings" "sync" "time" @@ -171,6 +172,15 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc return nil } + // Antigravity 账户:不可重试错误直接标记 error 状态并返回 + if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) { + errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err) + if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil { + log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, setErr) + } + return err + } + lastErr = err log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v", account.ID, attempt, s.cfg.MaxRetries, err) @@ -183,11 +193,37 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc } } - // 所有重试都失败,标记账号为error状态 - errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr) - if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { - log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err) + // Antigravity 账户:其他错误仅记录日志,不标记 error(可能是临时网络问题) + // 其他平台账户:重试失败后标记 error + if account.Platform == PlatformAntigravity { + log.Printf("[TokenRefresh] Account %d: refresh failed after %d retries: %v", account.ID, s.cfg.MaxRetries, lastErr) + } else { + errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr) + if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { + log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err) + } } return lastErr } + +// isNonRetryableRefreshError 判断是否为不可重试的刷新错误 +// 这些错误通常表示凭证已失效,需要用户重新授权 +func isNonRetryableRefreshError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + nonRetryable := []string{ + "invalid_grant", // refresh_token 已失效 + "invalid_client", // 客户端配置错误 + "unauthorized_client", // 客户端未授权 + "access_denied", // 访问被拒绝 + } + for _, needle := range nonRetryable { + if strings.Contains(msg, needle) { + return true + } + } + return false +} diff --git a/frontend/src/components/common/DateRangePicker.vue b/frontend/src/components/common/DateRangePicker.vue index be641f9b..4fce029f 100644 --- a/frontend/src/components/common/DateRangePicker.vue +++ b/frontend/src/components/common/DateRangePicker.vue @@ -59,7 +59,7 @@ @@ -85,7 +85,7 @@ type="date" v-model="localEndDate" :min="localStartDate" - :max="today" + :max="tomorrow" class="date-picker-input" @change="onDateChange" /> @@ -144,6 +144,14 @@ const today = computed(() => { return `${year}-${month}-${day}` }) +// Tomorrow's date - used for max date to handle timezone differences +// When user is in a timezone behind the server, "today" on server might be "tomorrow" locally +const tomorrow = computed(() => { + const d = new Date() + d.setDate(d.getDate() + 1) + return formatDateToString(d) +}) + // Helper function to format date to YYYY-MM-DD using local timezone const formatDateToString = (date: Date): string => { const year = date.getFullYear() diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index a3e4c25e..2f6aa2c8 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -290,7 +290,7 @@ export interface UpdateGroupRequest { export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity' export type AccountType = 'oauth' | 'setup-token' | 'apikey' export type OAuthAddMethod = 'oauth' | 'setup-token' -export type ProxyProtocol = 'http' | 'https' | 'socks5' +export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h' // Claude Model type (returned by /v1/models and account models API) export interface ClaudeModel { diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index a5df9bd0..613b503c 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -90,7 +90,7 @@