chore(合并): 同步主分支变更并解决冲突
- 合并 wire/httpclient/http_upstream/proxy_probe 冲突并保留校验逻辑 - 引入 proxyutil 及测试,完善代理配置 - 更新 goreleaser/workflow 与前端细节调整 测试: go test ./...
This commit is contained in:
63
.github/workflows/release.yml
vendored
63
.github/workflows/release.yml
vendored
@@ -4,6 +4,22 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: 'Tag to release (e.g., v1.0.0)'
|
||||
required: true
|
||||
type: string
|
||||
simple_release:
|
||||
description: 'Simple release: only x86_64 GHCR image, skip other artifacts'
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
# 环境变量:合并 workflow_dispatch 输入和 repository variable
|
||||
# tag push 触发时读取 vars.SIMPLE_RELEASE,workflow_dispatch 时使用输入参数
|
||||
env:
|
||||
SIMPLE_RELEASE: ${{ github.event.inputs.simple_release == 'true' || vars.SIMPLE_RELEASE == 'true' }}
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -19,7 +35,12 @@ jobs:
|
||||
|
||||
- name: Update VERSION file
|
||||
run: |
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
VERSION=${{ github.event.inputs.tag }}
|
||||
VERSION=${VERSION#v}
|
||||
else
|
||||
VERSION=${GITHUB_REF#refs/tags/v}
|
||||
fi
|
||||
echo "$VERSION" > backend/cmd/server/VERSION
|
||||
echo "Updated VERSION file to: $VERSION"
|
||||
|
||||
@@ -66,6 +87,7 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.tag || github.ref }}
|
||||
|
||||
- name: Download VERSION artifact
|
||||
uses: actions/download-artifact@v4
|
||||
@@ -93,7 +115,10 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to DockerHub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -113,7 +138,11 @@ jobs:
|
||||
- name: Get tag message
|
||||
id: tag_message
|
||||
run: |
|
||||
TAG_NAME=${GITHUB_REF#refs/tags/}
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
TAG_NAME=${{ github.event.inputs.tag }}
|
||||
else
|
||||
TAG_NAME=${GITHUB_REF#refs/tags/}
|
||||
fi
|
||||
echo "Processing tag: $TAG_NAME"
|
||||
|
||||
# 获取完整的 tag message(跳过第一行标题)
|
||||
@@ -137,18 +166,21 @@ jobs:
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
version: '~> v2'
|
||||
args: release --clean --skip=validate
|
||||
args: release --clean --skip=validate ${{ env.SIMPLE_RELEASE == 'true' && '--config=.goreleaser.simple.yaml' || '' }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
TAG_MESSAGE: ${{ steps.tag_message.outputs.message }}
|
||||
GITHUB_REPO_OWNER: ${{ github.repository_owner }}
|
||||
GITHUB_REPO_OWNER_LOWER: ${{ steps.lowercase.outputs.owner }}
|
||||
GITHUB_REPO_NAME: ${{ github.event.repository.name }}
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME || 'skip' }}
|
||||
|
||||
# Update DockerHub description
|
||||
- name: Update DockerHub description
|
||||
if: ${{ env.SIMPLE_RELEASE != 'true' && env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: peter-evans/dockerhub-description@v4
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
@@ -158,9 +190,11 @@ jobs:
|
||||
|
||||
# Send Telegram notification
|
||||
- name: Send Telegram Notification
|
||||
if: ${{ env.SIMPLE_RELEASE != 'true' }}
|
||||
env:
|
||||
TELEGRAM_BOT_TOKEN: ${{ secrets.TELEGRAM_BOT_TOKEN }}
|
||||
TELEGRAM_CHAT_ID: ${{ secrets.TELEGRAM_CHAT_ID }}
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
continue-on-error: true
|
||||
run: |
|
||||
# 检查必要的环境变量
|
||||
@@ -169,10 +203,13 @@ jobs:
|
||||
exit 0
|
||||
fi
|
||||
|
||||
TAG_NAME=${GITHUB_REF#refs/tags/}
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
TAG_NAME=${{ github.event.inputs.tag }}
|
||||
else
|
||||
TAG_NAME=${GITHUB_REF#refs/tags/}
|
||||
fi
|
||||
VERSION=${TAG_NAME#v}
|
||||
REPO="${{ github.repository }}"
|
||||
DOCKER_IMAGE="${{ secrets.DOCKERHUB_USERNAME }}/sub2api"
|
||||
GHCR_IMAGE="ghcr.io/${REPO,,}" # ${,,} converts to lowercase
|
||||
|
||||
# 获取 tag message 内容
|
||||
@@ -194,14 +231,20 @@ jobs:
|
||||
|
||||
MESSAGE+="🐳 *Docker 部署:*"$'\n'
|
||||
MESSAGE+="\`\`\`bash"$'\n'
|
||||
MESSAGE+="# Docker Hub"$'\n'
|
||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
|
||||
MESSAGE+="# GitHub Container Registry"$'\n'
|
||||
# 根据是否配置 DockerHub 动态生成
|
||||
if [ -n "$DOCKERHUB_USERNAME" ]; then
|
||||
DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api"
|
||||
MESSAGE+="# Docker Hub"$'\n'
|
||||
MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n'
|
||||
MESSAGE+="# GitHub Container Registry"$'\n'
|
||||
fi
|
||||
MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n'
|
||||
MESSAGE+="\`\`\`"$'\n'$'\n'
|
||||
MESSAGE+="🔗 *相关链接:*"$'\n'
|
||||
MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n'
|
||||
MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n'
|
||||
if [ -n "$DOCKERHUB_USERNAME" ]; then
|
||||
MESSAGE+="• [Docker Hub](https://hub.docker.com/r/${DOCKER_IMAGE})"$'\n'
|
||||
fi
|
||||
MESSAGE+="• [GitHub Packages](https://github.com/${REPO}/pkgs/container/sub2api)"$'\n'$'\n'
|
||||
MESSAGE+="#Sub2API #Release #${TAG_NAME//./_}"
|
||||
|
||||
|
||||
86
.goreleaser.simple.yaml
Normal file
86
.goreleaser.simple.yaml
Normal file
@@ -0,0 +1,86 @@
|
||||
# 简化版 GoReleaser 配置 - 仅发布 x86_64 GHCR 镜像
|
||||
version: 2
|
||||
|
||||
project_name: sub2api
|
||||
|
||||
before:
|
||||
hooks:
|
||||
- go mod tidy -C backend
|
||||
|
||||
builds:
|
||||
- id: sub2api
|
||||
dir: backend
|
||||
main: ./cmd/server
|
||||
binary: sub2api
|
||||
flags:
|
||||
- -tags=embed
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
goarch:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.Commit={{.Commit}}
|
||||
- -X main.Date={{.Date}}
|
||||
- -X main.BuildType=release
|
||||
|
||||
# 跳过 archives
|
||||
archives: []
|
||||
|
||||
# 跳过 checksum
|
||||
checksum:
|
||||
disable: true
|
||||
|
||||
changelog:
|
||||
disable: true
|
||||
|
||||
# 仅 GHCR x86_64 镜像
|
||||
dockers:
|
||||
- id: ghcr-amd64
|
||||
goos: linux
|
||||
goarch: amd64
|
||||
image_templates:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}-amd64"
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}"
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:latest"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
use: buildx
|
||||
build_flag_templates:
|
||||
- "--platform=linux/amd64"
|
||||
- "--label=org.opencontainers.image.version={{ .Version }}"
|
||||
- "--label=org.opencontainers.image.revision={{ .Commit }}"
|
||||
- "--label=org.opencontainers.image.source=https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }}"
|
||||
|
||||
# 跳过 manifests(单架构不需要)
|
||||
docker_manifests: []
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: "{{ .Env.GITHUB_REPO_OWNER }}"
|
||||
name: "{{ .Env.GITHUB_REPO_NAME }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
name_template: "Sub2API {{.Version}} (Simple)"
|
||||
# 跳过上传二进制包
|
||||
skip_upload: true
|
||||
header: |
|
||||
> AI API Gateway Platform - 将 AI 订阅配额分发和管理
|
||||
> ⚡ Simple Release: 仅包含 x86_64 GHCR 镜像
|
||||
|
||||
{{ .Env.TAG_MESSAGE }}
|
||||
|
||||
footer: |
|
||||
---
|
||||
|
||||
## 📥 Installation
|
||||
|
||||
**Docker (x86_64 only):**
|
||||
```bash
|
||||
docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}
|
||||
```
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [GitHub Repository](https://github.com/{{ .Env.GITHUB_REPO_OWNER }}/{{ .Env.GITHUB_REPO_NAME }})
|
||||
@@ -54,9 +54,11 @@ changelog:
|
||||
|
||||
# Docker images
|
||||
dockers:
|
||||
# DockerHub images (skipped if DOCKERHUB_USERNAME is 'skip')
|
||||
- id: amd64
|
||||
goos: linux
|
||||
goarch: amd64
|
||||
skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}'
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
@@ -69,6 +71,7 @@ dockers:
|
||||
- id: arm64
|
||||
goos: linux
|
||||
goarch: arm64
|
||||
skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}'
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
@@ -107,22 +110,27 @@ dockers:
|
||||
|
||||
# Docker manifests for multi-arch support
|
||||
docker_manifests:
|
||||
# DockerHub manifests (skipped if DOCKERHUB_USERNAME is 'skip')
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}"
|
||||
skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}'
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:latest"
|
||||
skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}'
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}.{{ .Minor }}"
|
||||
skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}'
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
|
||||
- name_template: "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Major }}"
|
||||
skip_push: '{{ if eq .Env.DOCKERHUB_USERNAME "skip" }}true{{ else }}false{{ end }}'
|
||||
image_templates:
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-amd64"
|
||||
- "{{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}-arm64"
|
||||
@@ -169,9 +177,11 @@ release:
|
||||
|
||||
**Docker:**
|
||||
```bash
|
||||
{{ if ne .Env.DOCKERHUB_USERNAME "skip" -}}
|
||||
# Docker Hub
|
||||
docker pull {{ .Env.DOCKERHUB_USERNAME }}/sub2api:{{ .Version }}
|
||||
|
||||
{{ end -}}
|
||||
# GitHub Container Registry
|
||||
docker pull ghcr.io/{{ .Env.GITHUB_REPO_OWNER_LOWER }}/sub2api:{{ .Version }}
|
||||
```
|
||||
|
||||
@@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.NewConcurrencyService(concurrencyCache)
|
||||
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
|
||||
// CreateProxyRequest represents create proxy request
|
||||
type CreateProxyRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
@@ -36,7 +36,7 @@ type CreateProxyRequest struct {
|
||||
// UpdateProxyRequest represents update proxy request
|
||||
type UpdateProxyRequest struct {
|
||||
Name string `json:"name"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
|
||||
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5 socks5h"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
@@ -255,7 +255,7 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
|
||||
|
||||
// BatchCreateProxyItem represents a single proxy in batch create request
|
||||
type BatchCreateProxyItem struct {
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
|
||||
Protocol string `json:"protocol" binding:"required,oneof=http https socks5 socks5h"`
|
||||
Host string `json:"host" binding:"required"`
|
||||
Port int `json:"port" binding:"required,min=1,max=65535"`
|
||||
Username string `json:"username"`
|
||||
|
||||
@@ -389,7 +389,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
// 错误响应已在Forward中处理,这里只记录日志
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
// Error response already handled in Forward, just log
|
||||
log.Printf("Forward request failed: %v", err)
|
||||
log.Printf("Account %d: Forward request failed: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -11,23 +11,21 @@
|
||||
// 新实现使用统一的客户端池:
|
||||
// 1. 相同配置复用同一 http.Client 实例
|
||||
// 2. 复用 Transport 连接池,减少 TCP/TLS 握手开销
|
||||
// 3. 支持 HTTP/HTTPS/SOCKS5 代理
|
||||
// 4. 支持严格代理模式(代理失败则返回错误)
|
||||
// 3. 支持 HTTP/HTTPS/SOCKS5/SOCKS5H 代理
|
||||
// 4. 代理配置失败时直接返回错误,不会回退到直连(避免 IP 关联风险)
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// Transport 连接池默认配置
|
||||
@@ -39,7 +37,7 @@ const (
|
||||
|
||||
// Options 定义共享 HTTP 客户端的构建参数
|
||||
type Options struct {
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5)
|
||||
ProxyURL string // 代理 URL(支持 http/https/socks5/socks5h)
|
||||
Timeout time.Duration // 请求总超时时间
|
||||
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
InsecureSkipVerify bool // 是否跳过 TLS 证书验证
|
||||
@@ -58,6 +56,7 @@ var sharedClients sync.Map
|
||||
|
||||
// GetClient 返回共享的 HTTP 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
||||
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
|
||||
func GetClient(opts Options) (*http.Client, error) {
|
||||
key := buildClientKey(opts)
|
||||
if cached, ok := sharedClients.Load(key); ok {
|
||||
@@ -68,12 +67,7 @@ func GetClient(opts Options) (*http.Client, error) {
|
||||
|
||||
client, err := buildClient(opts)
|
||||
if err != nil {
|
||||
if opts.ProxyStrict {
|
||||
return nil, err
|
||||
}
|
||||
fallback := opts
|
||||
fallback.ProxyURL = ""
|
||||
client, _ = buildClient(fallback)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
actual, _ := sharedClients.LoadOrStore(key, client)
|
||||
@@ -132,19 +126,8 @@ func buildTransport(opts Options) (*http.Transport, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch strings.ToLower(parsed.Scheme) {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(parsed)
|
||||
case "socks5", "socks5h":
|
||||
dialer, err := proxy.FromURL(parsed, proxy.Direct)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported proxy protocol: %s", parsed.Scheme)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
|
||||
62
backend/internal/pkg/proxyutil/dialer.go
Normal file
62
backend/internal/pkg/proxyutil/dialer.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Package proxyutil 提供统一的代理配置功能
|
||||
//
|
||||
// 支持的代理协议:
|
||||
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
|
||||
// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// ConfigureTransportProxy 根据代理 URL 配置 Transport
|
||||
//
|
||||
// 支持的协议:
|
||||
// - http/https: 设置 transport.Proxy
|
||||
// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
|
||||
//
|
||||
// 参数:
|
||||
// - transport: 需要配置的 http.Transport
|
||||
// - proxyURL: 代理地址,nil 表示直连
|
||||
//
|
||||
// 返回:
|
||||
// - error: 代理配置错误(协议不支持或 dialer 创建失败)
|
||||
func ConfigureTransportProxy(transport *http.Transport, proxyURL *url.URL) error {
|
||||
if proxyURL == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(proxyURL.Scheme)
|
||||
switch scheme {
|
||||
case "http", "https":
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
return nil
|
||||
|
||||
case "socks5", "socks5h":
|
||||
dialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create socks5 dialer: %w", err)
|
||||
}
|
||||
// 优先使用支持 context 的 DialContext,以支持请求取消和超时
|
||||
if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
transport.DialContext = contextDialer.DialContext
|
||||
} else {
|
||||
// 回退路径:如果 dialer 不支持 ContextDialer,则包装为简单的 DialContext
|
||||
// 注意:此回退不支持请求取消和超时控制
|
||||
transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported proxy scheme: %s", scheme)
|
||||
}
|
||||
}
|
||||
204
backend/internal/pkg/proxyutil/dialer_test.go
Normal file
204
backend/internal/pkg/proxyutil/dialer_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfigureTransportProxy_Nil(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
err := ConfigureTransportProxy(transport, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, transport.Proxy, "nil proxy should not set Proxy")
|
||||
assert.Nil(t, transport.DialContext, "nil proxy should not set DialContext")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_HTTP(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse("http://proxy.example.com:8080")
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, transport.Proxy, "HTTP proxy should set Proxy")
|
||||
assert.Nil(t, transport.DialContext, "HTTP proxy should not set DialContext")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_HTTPS(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse("https://secure-proxy.example.com:8443")
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, transport.Proxy, "HTTPS proxy should set Proxy")
|
||||
assert.Nil(t, transport.DialContext, "HTTPS proxy should not set DialContext")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_SOCKS5(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, transport.Proxy, "SOCKS5 proxy should not set Proxy")
|
||||
assert.NotNil(t, transport.DialContext, "SOCKS5 proxy should set DialContext")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_SOCKS5H(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse("socks5h://socks.example.com:1080")
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, transport.Proxy, "SOCKS5H proxy should not set Proxy")
|
||||
assert.NotNil(t, transport.DialContext, "SOCKS5H proxy should set DialContext")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_CaseInsensitive(t *testing.T) {
|
||||
testCases := []struct {
|
||||
scheme string
|
||||
useProxy bool // true = uses Transport.Proxy, false = uses DialContext
|
||||
}{
|
||||
{"HTTP://proxy.example.com:8080", true},
|
||||
{"Http://proxy.example.com:8080", true},
|
||||
{"HTTPS://proxy.example.com:8443", true},
|
||||
{"Https://proxy.example.com:8443", true},
|
||||
{"SOCKS5://socks.example.com:1080", false},
|
||||
{"Socks5://socks.example.com:1080", false},
|
||||
{"SOCKS5H://socks.example.com:1080", false},
|
||||
{"Socks5h://socks.example.com:1080", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.scheme, func(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse(tc.scheme)
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
if tc.useProxy {
|
||||
assert.NotNil(t, transport.Proxy)
|
||||
assert.Nil(t, transport.DialContext)
|
||||
} else {
|
||||
assert.Nil(t, transport.Proxy)
|
||||
assert.NotNil(t, transport.DialContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_Unsupported(t *testing.T) {
|
||||
testCases := []string{
|
||||
"ftp://ftp.example.com",
|
||||
"file:///path/to/file",
|
||||
"unknown://example.com",
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc, func(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse(tc)
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported proxy scheme")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_WithAuth(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, _ := url.Parse("socks5://user:password@socks.example.com:1080")
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, transport.DialContext, "SOCKS5 with auth should set DialContext")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_EmptyScheme(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
// 空 scheme 的 URL
|
||||
proxyURL := &url.URL{Host: "proxy.example.com:8080"}
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported proxy scheme")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_PreservesExistingConfig(t *testing.T) {
|
||||
// 验证代理配置不会覆盖 Transport 的其他配置
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
}
|
||||
proxyURL, _ := url.Parse("socks5://socks.example.com:1080")
|
||||
|
||||
err := ConfigureTransportProxy(transport, proxyURL)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 100, transport.MaxIdleConns, "MaxIdleConns should be preserved")
|
||||
assert.Equal(t, 10, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost should be preserved")
|
||||
assert.NotNil(t, transport.DialContext, "DialContext should be set")
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_IPv6(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proxyURL string
|
||||
}{
|
||||
{"SOCKS5H with IPv6 loopback", "socks5h://[::1]:1080"},
|
||||
{"SOCKS5 with full IPv6", "socks5://[2001:db8::1]:1080"},
|
||||
{"HTTP with IPv6", "http://[::1]:8080"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, err := url.Parse(tc.proxyURL)
|
||||
require.NoError(t, err, "URL should be parseable")
|
||||
|
||||
err = ConfigureTransportProxy(transport, proxyURL)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTransportProxy_SpecialCharsInPassword(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proxyURL string
|
||||
}{
|
||||
// 密码包含 @ 符号(URL 编码为 %40)
|
||||
{"password with @", "socks5://user:p%40ssword@proxy.example.com:1080"},
|
||||
// 密码包含 : 符号(URL 编码为 %3A)
|
||||
{"password with :", "socks5://user:pass%3Aword@proxy.example.com:1080"},
|
||||
// 密码包含 / 符号(URL 编码为 %2F)
|
||||
{"password with /", "socks5://user:pass%2Fword@proxy.example.com:1080"},
|
||||
// 复杂密码
|
||||
{"complex password", "socks5h://admin:P%40ss%3Aw0rd%2F123@proxy.example.com:1080"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
transport := &http.Transport{}
|
||||
proxyURL, err := url.Parse(tc.proxyURL)
|
||||
require.NoError(t, err, "URL should be parseable")
|
||||
|
||||
err = ConfigureTransportProxy(transport, proxyURL)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, transport.DialContext, "SOCKS5 should set DialContext")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -234,11 +234,17 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) *req.Client {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
Impersonate: true,
|
||||
})
|
||||
// 禁用 CookieJar,确保每次授权都是干净的会话
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second).
|
||||
ImpersonateChrome().
|
||||
SetCookieJar(nil) // 禁用 CookieJar
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
client.SetProxyURL(strings.TrimSpace(proxyURL))
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func prefix(s string, n int) string {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
@@ -261,7 +262,12 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
|
||||
|
||||
// 缓存未命中或需要重建,创建新客户端
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
client := &http.Client{Transport: buildUpstreamTransport(settings, parsedProxy)}
|
||||
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build transport: %w", err)
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
@@ -587,6 +593,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 代理配置错误
|
||||
//
|
||||
// Transport 参数说明:
|
||||
// - MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
@@ -594,7 +601,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
|
||||
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
|
||||
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
|
||||
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
|
||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Transport {
|
||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
@@ -602,10 +609,10 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) *http.Tran
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
}
|
||||
if proxyURL != nil {
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transport
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
|
||||
@@ -45,8 +45,12 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
|
||||
settings := defaultPoolSettings(cfg)
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 每次迭代都创建新客户端,包含 Transport 分配
|
||||
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||
if err != nil {
|
||||
b.Fatalf("创建 Transport 失败: %v", err)
|
||||
}
|
||||
httpClientSink = &http.Client{
|
||||
Transport: buildUpstreamTransport(settings, parsedProxy),
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
@@ -31,7 +30,6 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`)
|
||||
|
||||
const (
|
||||
testClaudeAPIURL = "https://api.anthropic.com/v1/messages"
|
||||
testOpenAIAPIURL = "https://api.openai.com/v1/responses"
|
||||
chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
)
|
||||
|
||||
@@ -47,8 +45,6 @@ type TestEvent struct {
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
accountRepo AccountRepository
|
||||
oauthService *OAuthService
|
||||
openaiOAuthService *OpenAIOAuthService
|
||||
geminiTokenProvider *GeminiTokenProvider
|
||||
antigravityGatewayService *AntigravityGatewayService
|
||||
httpUpstream HTTPUpstream
|
||||
@@ -58,8 +54,6 @@ type AccountTestService struct {
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(
|
||||
accountRepo AccountRepository,
|
||||
oauthService *OAuthService,
|
||||
openaiOAuthService *OpenAIOAuthService,
|
||||
geminiTokenProvider *GeminiTokenProvider,
|
||||
antigravityGatewayService *AntigravityGatewayService,
|
||||
httpUpstream HTTPUpstream,
|
||||
@@ -67,8 +61,6 @@ func NewAccountTestService(
|
||||
) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
openaiOAuthService: openaiOAuthService,
|
||||
geminiTokenProvider: geminiTokenProvider,
|
||||
antigravityGatewayService: antigravityGatewayService,
|
||||
httpUpstream: httpUpstream,
|
||||
@@ -204,22 +196,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
if authToken == "" {
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
needRefresh := false
|
||||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||
if time.Now().Add(5 * time.Minute).After(*expiresAt) {
|
||||
needRefresh = true
|
||||
}
|
||||
}
|
||||
|
||||
if needRefresh && s.oauthService != nil {
|
||||
tokenInfo, err := s.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||
}
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
} else if account.Type == "apikey" {
|
||||
// API Key - use x-api-key header
|
||||
useBearer = false
|
||||
@@ -335,15 +311,6 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
return s.sendErrorAndEnd(c, "No access token available")
|
||||
}
|
||||
|
||||
// Check if token is expired and refresh if needed
|
||||
if account.IsOpenAITokenExpired() && s.openaiOAuthService != nil {
|
||||
tokenInfo, err := s.openaiOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to refresh token: %s", err.Error()))
|
||||
}
|
||||
authToken = tokenInfo.AccessToken
|
||||
}
|
||||
|
||||
// OAuth uses ChatGPT internal API
|
||||
apiURL = chatgptCodexAPIURL
|
||||
chatgptAccountID = account.GetChatGPTAccountID()
|
||||
|
||||
@@ -76,7 +76,7 @@ type antigravityUsageCache struct {
|
||||
}
|
||||
|
||||
const (
|
||||
apiCacheTTL = 10 * time.Minute
|
||||
apiCacheTTL = 3 * time.Minute
|
||||
windowStatsCacheTTL = 1 * time.Minute
|
||||
)
|
||||
|
||||
|
||||
@@ -630,6 +630,7 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
Concurrency: input.Concurrency,
|
||||
Priority: input.Priority,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -22,11 +22,27 @@ import (
|
||||
|
||||
const (
|
||||
antigravityStickySessionTTL = time.Hour
|
||||
antigravityMaxRetries = 5
|
||||
antigravityMaxRetries = 3
|
||||
antigravityRetryBaseDelay = 1 * time.Second
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// getSessionID 从 gin.Context 获取 session_id(用于日志追踪)
|
||||
func getSessionID(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return c.GetHeader("session_id")
|
||||
}
|
||||
|
||||
// logPrefix 生成统一的日志前缀
|
||||
func logPrefix(sessionID, accountName string) string {
|
||||
if sessionID != "" {
|
||||
return fmt.Sprintf("[antigravity-Forward] session=%s account=%s", sessionID, accountName)
|
||||
}
|
||||
return fmt.Sprintf("[antigravity-Forward] account=%s", accountName)
|
||||
}
|
||||
|
||||
// Antigravity 直接支持的模型(精确匹配透传)
|
||||
var antigravitySupportedModels = map[string]bool{
|
||||
"claude-opus-4-5-thinking": true,
|
||||
@@ -48,10 +64,11 @@ var antigravityPrefixMapping = []struct {
|
||||
target string
|
||||
}{
|
||||
// 长前缀优先
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
{"gemini-2.5-flash-image", "gemini-3-pro-image"}, // gemini-2.5-flash-image → 3-pro-image
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "claude-sonnet-4-5"}, // claude-haiku-4-5-xxx → sonnet
|
||||
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||
{"claude-3-haiku", "claude-sonnet-4-5"}, // 旧版 claude-3-haiku-xxx → sonnet
|
||||
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||
@@ -315,6 +332,8 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
|
||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
|
||||
// 解析 Claude 请求
|
||||
var claudeReq antigravity.ClaudeRequest
|
||||
@@ -369,10 +388,11 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
@@ -381,13 +401,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
// 最后一次尝试也失败
|
||||
resp = &http.Response{
|
||||
@@ -405,7 +425,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
@@ -424,6 +444,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
if claudeReq.Stream {
|
||||
streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel)
|
||||
if err != nil {
|
||||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
@@ -448,6 +469,8 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
sessionID := getSessionID(c)
|
||||
prefix := logPrefix(sessionID, account.Name)
|
||||
|
||||
if strings.TrimSpace(originalModel) == "" {
|
||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Missing model in URL")
|
||||
@@ -523,10 +546,11 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err)
|
||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
||||
}
|
||||
|
||||
@@ -535,13 +559,13 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries)
|
||||
sleepAntigravityBackoff(attempt)
|
||||
continue
|
||||
}
|
||||
// 所有重试都失败,标记限流状态
|
||||
if resp.StatusCode == 429 {
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -563,7 +587,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
@@ -585,6 +609,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
if stream || upstreamAction == "streamGenerateContent" {
|
||||
streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime)
|
||||
if err != nil {
|
||||
log.Printf("%s status=stream_error error=%v", prefix, err)
|
||||
return nil, err
|
||||
}
|
||||
usage = streamRes.usage
|
||||
@@ -633,7 +658,7 @@ func sleepAntigravityBackoff(attempt int) {
|
||||
sleepGeminiBackoff(attempt) // 复用 Gemini 的退避逻辑
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte) {
|
||||
// 429 使用 Gemini 格式解析(从 body 解析重置时间)
|
||||
if statusCode == 429 {
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
@@ -644,17 +669,23 @@ func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, acc
|
||||
defaultDur = 5 * time.Minute
|
||||
}
|
||||
ra := time.Now().Add(defaultDur)
|
||||
log.Printf("%s status=429 rate_limited reset_in=%v (fallback)", prefix, defaultDur)
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||
return
|
||||
}
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
log.Printf("%s status=429 rate_limited reset_at=%v reset_in=%v", prefix, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
||||
return
|
||||
}
|
||||
// 其他错误码继续使用 rateLimitService
|
||||
if s.rateLimitService == nil {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body)
|
||||
if shouldDisable {
|
||||
log.Printf("%s status=%d marked_error", prefix, statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
type antigravityStreamResult struct {
|
||||
@@ -851,7 +882,7 @@ func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int,
|
||||
|
||||
func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, upstreamStatus int, body []byte) error {
|
||||
// 记录上游错误详情便于调试
|
||||
log.Printf("Antigravity upstream error %d: %s", upstreamStatus, string(body))
|
||||
log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, string(body))
|
||||
|
||||
var statusCode int
|
||||
var errType, errMsg string
|
||||
@@ -925,7 +956,7 @@ func (s *AntigravityGatewayService) handleClaudeNonStreamingResponse(c *gin.Cont
|
||||
// 转换 Gemini 响应为 Claude 格式
|
||||
claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(body, originalModel)
|
||||
if err != nil {
|
||||
log.Printf("Transform Gemini to Claude failed: %v, body: %s", err, string(body))
|
||||
log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(body))
|
||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||
}
|
||||
|
||||
|
||||
233
backend/internal/service/gateway_prompt_test.go
Normal file
233
backend/internal/service/gateway_prompt_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsClaudeCodeClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userAgent string
|
||||
metadataUserID string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Claude Code client",
|
||||
userAgent: "claude-cli/1.0.62 (darwin; arm64)",
|
||||
metadataUserID: "session_123e4567-e89b-12d3-a456-426614174000",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Claude Code without version suffix",
|
||||
userAgent: "claude-cli/2.0.0",
|
||||
metadataUserID: "session_abc",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Missing metadata user_id",
|
||||
userAgent: "claude-cli/1.0.0",
|
||||
metadataUserID: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Different user agent",
|
||||
userAgent: "curl/7.68.0",
|
||||
metadataUserID: "user123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Empty user agent",
|
||||
userAgent: "",
|
||||
metadataUserID: "user123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "Similar but not Claude CLI",
|
||||
userAgent: "claude-api/1.0.0",
|
||||
metadataUserID: "user123",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := isClaudeCodeClient(tt.userAgent, tt.metadataUserID)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemIncludesClaudeCodePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
system any
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil system",
|
||||
system: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
system: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "string with Claude Code prompt",
|
||||
system: claudeCodeSystemPrompt,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "string with different content",
|
||||
system: "You are a helpful assistant.",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
system: []any{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array with Claude Code prompt",
|
||||
system: []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array with Claude Code prompt in second position",
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": "First prompt"},
|
||||
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "array without Claude Code prompt",
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": "Custom prompt"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "array with partial match (should not match)",
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": "You are Claude"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := systemIncludesClaudeCodePrompt(tt.system)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectClaudeCodePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
system any
|
||||
wantSystemLen int
|
||||
wantFirstText string
|
||||
wantSecondText string
|
||||
}{
|
||||
{
|
||||
name: "nil system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: nil,
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
{
|
||||
name: "empty string system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: "",
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
{
|
||||
name: "string system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: "Custom prompt",
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom prompt",
|
||||
},
|
||||
{
|
||||
name: "string system equals Claude Code prompt",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: claudeCodeSystemPrompt,
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
{
|
||||
name: "array system",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: []any{map[string]any{"type": "text", "text": "Custom"}},
|
||||
// Claude Code + Custom = 2
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Custom",
|
||||
},
|
||||
{
|
||||
name: "array system with existing Claude Code prompt (should dedupe)",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: []any{
|
||||
map[string]any{"type": "text", "text": claudeCodeSystemPrompt},
|
||||
map[string]any{"type": "text", "text": "Other"},
|
||||
},
|
||||
// Claude Code at start + Other = 2 (deduped)
|
||||
wantSystemLen: 2,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
wantSecondText: "Other",
|
||||
},
|
||||
{
|
||||
name: "empty array",
|
||||
body: `{"model":"claude-3"}`,
|
||||
system: []any{},
|
||||
wantSystemLen: 1,
|
||||
wantFirstText: claudeCodeSystemPrompt,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := injectClaudeCodePrompt([]byte(tt.body), tt.system)
|
||||
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(result, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
system, ok := parsed["system"].([]any)
|
||||
require.True(t, ok, "system should be an array")
|
||||
require.Len(t, system, tt.wantSystemLen)
|
||||
|
||||
first, ok := system[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.wantFirstText, first["text"])
|
||||
require.Equal(t, "text", first["type"])
|
||||
|
||||
// Check cache_control
|
||||
cc, ok := first["cache_control"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "ephemeral", cc["type"])
|
||||
|
||||
if tt.wantSecondText != "" && len(system) > 1 {
|
||||
second, ok := system[1].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.wantSecondText, second["text"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -34,13 +34,15 @@ const (
|
||||
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
|
||||
stickySessionTTL = time.Hour // 粘性会话TTL
|
||||
defaultMaxLineSize = 10 * 1024 * 1024
|
||||
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
)
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
@@ -955,6 +957,76 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
|
||||
// 简化判断:User-Agent 匹配 + metadata.user_id 存在
|
||||
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
|
||||
if metadataUserID == "" {
|
||||
return false
|
||||
}
|
||||
return claudeCliUserAgentRe.MatchString(userAgent)
|
||||
}
|
||||
|
||||
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
|
||||
// 支持 string 和 []any 两种格式
|
||||
func systemIncludesClaudeCodePrompt(system any) bool {
|
||||
switch v := system.(type) {
|
||||
case string:
|
||||
return v == claudeCodeSystemPrompt
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
|
||||
// 处理 null、字符串、数组三种格式
|
||||
func injectClaudeCodePrompt(body []byte, system any) []byte {
|
||||
claudeCodeBlock := map[string]any{
|
||||
"type": "text",
|
||||
"text": claudeCodeSystemPrompt,
|
||||
"cache_control": map[string]string{"type": "ephemeral"},
|
||||
}
|
||||
|
||||
var newSystem []any
|
||||
|
||||
switch v := system.(type) {
|
||||
case nil:
|
||||
newSystem = []any{claudeCodeBlock}
|
||||
case string:
|
||||
if v == "" || v == claudeCodeSystemPrompt {
|
||||
newSystem = []any{claudeCodeBlock}
|
||||
} else {
|
||||
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}}
|
||||
}
|
||||
case []any:
|
||||
newSystem = make([]any, 0, len(v)+1)
|
||||
newSystem = append(newSystem, claudeCodeBlock)
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
|
||||
continue
|
||||
}
|
||||
}
|
||||
newSystem = append(newSystem, item)
|
||||
}
|
||||
default:
|
||||
newSystem = []any{claudeCodeBlock}
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "system", newSystem)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to inject Claude Code prompt: %v", err)
|
||||
return body
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -966,16 +1038,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
|
||||
if !parsed.HasSystem {
|
||||
body, _ = sjson.SetBytes(body, "system", []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
"cache_control": map[string]string{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
},
|
||||
})
|
||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||
if account.IsOAuth() &&
|
||||
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
|
||||
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||
body = injectClaudeCodePrompt(body, parsed.System)
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -171,6 +172,15 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
return nil
|
||||
}
|
||||
|
||||
// Antigravity 账户:不可重试错误直接标记 error 状态并返回
|
||||
if account.Platform == PlatformAntigravity && isNonRetryableRefreshError(err) {
|
||||
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
|
||||
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, setErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Printf("[TokenRefresh] Account %d attempt %d/%d failed: %v",
|
||||
account.ID, attempt, s.cfg.MaxRetries, err)
|
||||
@@ -183,11 +193,37 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
}
|
||||
}
|
||||
|
||||
// 所有重试都失败,标记账号为error状态
|
||||
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||
// Antigravity 账户:其他错误仅记录日志,不标记 error(可能是临时网络问题)
|
||||
// 其他平台账户:重试失败后标记 error
|
||||
if account.Platform == PlatformAntigravity {
|
||||
log.Printf("[TokenRefresh] Account %d: refresh failed after %d retries: %v", account.ID, s.cfg.MaxRetries, lastErr)
|
||||
} else {
|
||||
errorMsg := fmt.Sprintf("Token refresh failed after %d retries: %v", s.cfg.MaxRetries, lastErr)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("[TokenRefresh] Failed to set error status for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
|
||||
// 这些错误通常表示凭证已失效,需要用户重新授权
|
||||
func isNonRetryableRefreshError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
nonRetryable := []string{
|
||||
"invalid_grant", // refresh_token 已失效
|
||||
"invalid_client", // 客户端配置错误
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@
|
||||
<input
|
||||
type="date"
|
||||
v-model="localStartDate"
|
||||
:max="localEndDate || today"
|
||||
:max="localEndDate || tomorrow"
|
||||
class="date-picker-input"
|
||||
@change="onDateChange"
|
||||
/>
|
||||
@@ -85,7 +85,7 @@
|
||||
type="date"
|
||||
v-model="localEndDate"
|
||||
:min="localStartDate"
|
||||
:max="today"
|
||||
:max="tomorrow"
|
||||
class="date-picker-input"
|
||||
@change="onDateChange"
|
||||
/>
|
||||
@@ -144,6 +144,14 @@ const today = computed(() => {
|
||||
return `${year}-${month}-${day}`
|
||||
})
|
||||
|
||||
// Tomorrow's date - used for max date to handle timezone differences
|
||||
// When user is in a timezone behind the server, "today" on server might be "tomorrow" locally
|
||||
const tomorrow = computed(() => {
|
||||
const d = new Date()
|
||||
d.setDate(d.getDate() + 1)
|
||||
return formatDateToString(d)
|
||||
})
|
||||
|
||||
// Helper function to format date to YYYY-MM-DD using local timezone
|
||||
const formatDateToString = (date: Date): string => {
|
||||
const year = date.getFullYear()
|
||||
|
||||
@@ -290,7 +290,7 @@ export interface UpdateGroupRequest {
|
||||
export type AccountPlatform = 'anthropic' | 'openai' | 'gemini' | 'antigravity'
|
||||
export type AccountType = 'oauth' | 'setup-token' | 'apikey'
|
||||
export type OAuthAddMethod = 'oauth' | 'setup-token'
|
||||
export type ProxyProtocol = 'http' | 'https' | 'socks5'
|
||||
export type ProxyProtocol = 'http' | 'https' | 'socks5' | 'socks5h'
|
||||
|
||||
// Claude Model type (returned by /v1/models and account models API)
|
||||
export interface ClaudeModel {
|
||||
|
||||
@@ -90,7 +90,7 @@
|
||||
<template #cell-protocol="{ value }">
|
||||
<span
|
||||
v-if="value"
|
||||
:class="['badge', value === 'socks5' ? 'badge-primary' : 'badge-gray']"
|
||||
:class="['badge', value.startsWith('socks5') ? 'badge-primary' : 'badge-gray']"
|
||||
>
|
||||
{{ value.toUpperCase() }}
|
||||
</span>
|
||||
@@ -628,7 +628,8 @@ const protocolOptions = computed(() => [
|
||||
{ value: '', label: t('admin.proxies.allProtocols') },
|
||||
{ value: 'http', label: 'HTTP' },
|
||||
{ value: 'https', label: 'HTTPS' },
|
||||
{ value: 'socks5', label: 'SOCKS5' }
|
||||
{ value: 'socks5', label: 'SOCKS5' },
|
||||
{ value: 'socks5h', label: 'SOCKS5H' }
|
||||
])
|
||||
|
||||
const statusOptions = computed(() => [
|
||||
@@ -641,7 +642,8 @@ const statusOptions = computed(() => [
|
||||
const protocolSelectOptions = [
|
||||
{ value: 'http', label: 'HTTP' },
|
||||
{ value: 'https', label: 'HTTPS' },
|
||||
{ value: 'socks5', label: 'SOCKS5' }
|
||||
{ value: 'socks5', label: 'SOCKS5' },
|
||||
{ value: 'socks5h', label: 'SOCKS5H (服务端解析DNS)' }
|
||||
]
|
||||
|
||||
const editStatusOptions = computed(() => [
|
||||
@@ -798,8 +800,8 @@ const parseProxyUrl = (
|
||||
const trimmed = line.trim()
|
||||
if (!trimmed) return null
|
||||
|
||||
// Regex to parse proxy URL
|
||||
const regex = /^(https?|socks5):\/\/(?:([^:@]+):([^@]+)@)?([^:]+):(\d+)$/i
|
||||
// Regex to parse proxy URL (supports http, https, socks5, socks5h)
|
||||
const regex = /^(https?|socks5h?):\/\/(?:([^:@]+):([^@]+)@)?([^:]+):(\d+)$/i
|
||||
const match = trimmed.match(regex)
|
||||
|
||||
if (!match) return null
|
||||
|
||||
@@ -888,13 +888,17 @@ const formatLocalDate = (date: Date): string => {
|
||||
}
|
||||
|
||||
// Initialize date range immediately
|
||||
// Use tomorrow as end date to handle timezone differences between client and server
|
||||
// e.g., when server is in Asia/Shanghai and client is in America/Chicago
|
||||
const now = new Date()
|
||||
const tomorrow = new Date(now)
|
||||
tomorrow.setDate(tomorrow.getDate() + 1)
|
||||
const weekAgo = new Date(now)
|
||||
weekAgo.setDate(weekAgo.getDate() - 6)
|
||||
|
||||
// Date range state
|
||||
const startDate = ref(formatLocalDate(weekAgo))
|
||||
const endDate = ref(formatLocalDate(now))
|
||||
const endDate = ref(formatLocalDate(tomorrow))
|
||||
|
||||
const filters = ref<AdminUsageQueryParams>({
|
||||
user_id: undefined,
|
||||
@@ -1215,12 +1219,14 @@ const resetFilters = () => {
|
||||
end_date: undefined
|
||||
}
|
||||
granularity.value = 'day'
|
||||
// Reset date range to default (last 7 days)
|
||||
// Reset date range to default (last 7 days, with tomorrow as end to handle timezone differences)
|
||||
const now = new Date()
|
||||
const tomorrowDate = new Date(now)
|
||||
tomorrowDate.setDate(tomorrowDate.getDate() + 1)
|
||||
const weekAgo = new Date(now)
|
||||
weekAgo.setDate(weekAgo.getDate() - 6)
|
||||
startDate.value = formatLocalDate(weekAgo)
|
||||
endDate.value = formatLocalDate(now)
|
||||
endDate.value = formatLocalDate(tomorrowDate)
|
||||
filters.value.start_date = startDate.value
|
||||
filters.value.end_date = endDate.value
|
||||
pagination.value.page = 1
|
||||
|
||||
Reference in New Issue
Block a user