feat(sync): full code sync from release
This commit is contained in:
@@ -152,6 +152,7 @@ var claudeModels = []modelDef{
|
||||
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
|
||||
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
|
||||
}
|
||||
|
||||
@@ -165,6 +166,8 @@ var geminiModels = []modelDef{
|
||||
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
|
||||
}
|
||||
|
||||
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
26
backend/internal/pkg/antigravity/claude_types_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package antigravity
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
models := DefaultModels()
|
||||
byID := make(map[string]ClaudeModel, len(models))
|
||||
for _, m := range models {
|
||||
byID[m.ID] = m
|
||||
}
|
||||
|
||||
requiredIDs := []string{
|
||||
"claude-opus-4-6-thinking",
|
||||
"gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview",
|
||||
"gemini-3-pro-image", // legacy compatibility
|
||||
}
|
||||
|
||||
for _, id := range requiredIDs {
|
||||
if _, ok := byID[id]; !ok {
|
||||
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
|
||||
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持)
|
||||
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
|
||||
type GeminiImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"
|
||||
|
||||
@@ -53,7 +53,8 @@ const (
|
||||
var defaultUserAgentVersion = "1.19.6"
|
||||
|
||||
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
|
||||
var defaultClientSecret = "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"
|
||||
// 默认值使用占位符,生产环境请通过环境变量注入真实值。
|
||||
var defaultClientSecret = "GOCSPX-your-client-secret"
|
||||
|
||||
func init() {
|
||||
// 从环境变量读取版本号,未设置则使用默认值
|
||||
|
||||
@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
@@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||
}
|
||||
if secret != "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" {
|
||||
if secret != "GOCSPX-your-client-secret" {
|
||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
|
||||
@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
|
||||
md := map[string]string{"k": "v"}
|
||||
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
|
||||
|
||||
code, body := ToHTTP(appErr)
|
||||
require.Equal(t, http.StatusBadRequest, code)
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
|
||||
md["k"] = "changed"
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
|
||||
appErr.Metadata["k"] = "changed-again"
|
||||
require.Equal(t, "v", body.Metadata["k"])
|
||||
}
|
||||
|
||||
@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
|
||||
return http.StatusOK, Status{Code: int32(http.StatusOK)}
|
||||
}
|
||||
|
||||
cloned := Clone(appErr)
|
||||
return int(cloned.Code), cloned.Status
|
||||
body = Status{
|
||||
Code: appErr.Code,
|
||||
Reason: appErr.Reason,
|
||||
Message: appErr.Message,
|
||||
}
|
||||
if appErr.Metadata != nil {
|
||||
body.Metadata = make(map[string]string, len(appErr.Metadata))
|
||||
for k, v := range appErr.Metadata {
|
||||
body.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
return int(appErr.Code), body
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
// They enable the "login without creating your own OAuth client" experience, but Google may
|
||||
// restrict which scopes are allowed for this client.
|
||||
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||
GeminiCLIOAuthClientSecret = "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"
|
||||
GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret"
|
||||
|
||||
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
|
||||
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"
|
||||
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
defaultMaxIdleConns = 100 // 最大空闲连接数
|
||||
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
|
||||
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
|
||||
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
|
||||
)
|
||||
|
||||
// Options 定义共享 HTTP 客户端的构建参数
|
||||
@@ -53,6 +54,9 @@ type Options struct {
|
||||
// sharedClients 存储按配置参数缓存的 http.Client 实例
|
||||
var sharedClients sync.Map
|
||||
|
||||
// 允许测试替换校验函数,生产默认指向真实实现。
|
||||
var validateResolvedIP = urlvalidator.ValidateResolvedIP
|
||||
|
||||
// GetClient 返回共享的 HTTP 客户端实例
|
||||
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
|
||||
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
|
||||
@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
|
||||
|
||||
var rt http.RoundTripper = transport
|
||||
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
|
||||
rt = &validatedTransport{base: transport}
|
||||
rt = newValidatedTransport(transport)
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: rt,
|
||||
@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
|
||||
}
|
||||
|
||||
type validatedTransport struct {
|
||||
base http.RoundTripper
|
||||
base http.RoundTripper
|
||||
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
|
||||
return &validatedTransport{
|
||||
base: base,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
|
||||
if t == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := t.validatedHosts.Load(host)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
expireAt, ok := raw.(time.Time)
|
||||
if !ok {
|
||||
t.validatedHosts.Delete(host)
|
||||
return false
|
||||
}
|
||||
if now.Before(expireAt) {
|
||||
return true
|
||||
}
|
||||
t.validatedHosts.Delete(host)
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req != nil && req.URL != nil {
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
|
||||
if host != "" {
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
now := time.Now()
|
||||
if t != nil && t.now != nil {
|
||||
now = t.now()
|
||||
}
|
||||
if !t.isValidatedHost(host, now) {
|
||||
if err := validateResolvedIP(host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
|
||||
}
|
||||
}
|
||||
}
|
||||
if t == nil || t.base == nil {
|
||||
return nil, fmt.Errorf("validated transport base is nil")
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
115
backend/internal/pkg/httpclient/pool_test.go
Normal file
115
backend/internal/pkg/httpclient/pool_test.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
var validateCalls int32
|
||||
validateResolvedIP = func(host string) error {
|
||||
atomic.AddInt32(&validateCalls, 1)
|
||||
require.Equal(t, "api.openai.com", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
var baseCalls int32
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&baseCalls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
now := time.Unix(1730000000, 0)
|
||||
transport := newValidatedTransport(base)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
|
||||
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
var validateCalls int32
|
||||
validateResolvedIP = func(_ string) error {
|
||||
atomic.AddInt32(&validateCalls, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{}`)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
})
|
||||
|
||||
now := time.Unix(1730001000, 0)
|
||||
transport := newValidatedTransport(base)
|
||||
transport.now = func() time.Time { return now }
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
now = now.Add(validatedHostTTL + time.Second)
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
|
||||
}
|
||||
|
||||
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
|
||||
originalValidate := validateResolvedIP
|
||||
defer func() { validateResolvedIP = originalValidate }()
|
||||
|
||||
expectedErr := errors.New("dns rebinding rejected")
|
||||
validateResolvedIP = func(_ string) error {
|
||||
return expectedErr
|
||||
}
|
||||
|
||||
var baseCalls int32
|
||||
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&baseCalls, 1)
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
|
||||
})
|
||||
|
||||
transport := newValidatedTransport(base)
|
||||
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = transport.RoundTrip(req)
|
||||
require.ErrorIs(t, err, expectedErr)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
|
||||
}
|
||||
37
backend/internal/pkg/httputil/body.go
Normal file
37
backend/internal/pkg/httputil/body.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
requestBodyReadInitCap = 512
|
||||
requestBodyReadMaxInitCap = 1 << 20
|
||||
)
|
||||
|
||||
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
|
||||
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
|
||||
if req == nil || req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
capHint := requestBodyReadInitCap
|
||||
if req.ContentLength > 0 {
|
||||
switch {
|
||||
case req.ContentLength < int64(requestBodyReadInitCap):
|
||||
capHint = requestBodyReadInitCap
|
||||
case req.ContentLength > int64(requestBodyReadMaxInitCap):
|
||||
capHint = requestBodyReadMaxInitCap
|
||||
default:
|
||||
capHint = int(req.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(make([]byte, 0, capHint))
|
||||
if _, err := io.Copy(buf, req.Body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
|
||||
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
|
||||
var privateNets []*net.IPNet
|
||||
|
||||
// CompiledIPRules 表示预编译的 IP 匹配规则。
|
||||
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
|
||||
type CompiledIPRules struct {
|
||||
CIDRs []*net.IPNet
|
||||
IPs []net.IP
|
||||
PatternCount int
|
||||
}
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
@@ -84,6 +92,53 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
|
||||
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
|
||||
func CompileIPRules(patterns []string) *CompiledIPRules {
|
||||
compiled := &CompiledIPRules{
|
||||
CIDRs: make([]*net.IPNet, 0, len(patterns)),
|
||||
IPs: make([]net.IP, 0, len(patterns)),
|
||||
PatternCount: len(patterns),
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
normalized := strings.TrimSpace(pattern)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(normalized, "/") {
|
||||
_, cidr, err := net.ParseCIDR(normalized)
|
||||
if err != nil || cidr == nil {
|
||||
continue
|
||||
}
|
||||
compiled.CIDRs = append(compiled.CIDRs, cidr)
|
||||
continue
|
||||
}
|
||||
parsedIP := net.ParseIP(normalized)
|
||||
if parsedIP == nil {
|
||||
continue
|
||||
}
|
||||
compiled.IPs = append(compiled.IPs, parsedIP)
|
||||
}
|
||||
return compiled
|
||||
}
|
||||
|
||||
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
|
||||
if parsedIP == nil || rules == nil {
|
||||
return false
|
||||
}
|
||||
for _, cidr := range rules.CIDRs {
|
||||
if cidr.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, ruleIP := range rules.IPs {
|
||||
if parsedIP.Equal(ruleIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateIP 检查 IP 是否为私有地址。
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
|
||||
// 2. 如果白名单不为空,IP 必须在白名单中
|
||||
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
|
||||
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
|
||||
return CheckIPRestrictionWithCompiledRules(
|
||||
clientIP,
|
||||
CompileIPRules(whitelist),
|
||||
CompileIPRules(blacklist),
|
||||
)
|
||||
}
|
||||
|
||||
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
|
||||
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
|
||||
// 规范化 IP
|
||||
clientIP = normalizeIP(clientIP)
|
||||
if clientIP == "" {
|
||||
return false, "access denied"
|
||||
}
|
||||
parsedIP := net.ParseIP(clientIP)
|
||||
if parsedIP == nil {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 1. 检查黑名单
|
||||
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
|
||||
if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
// 2. 检查白名单(如果设置了白名单,IP 必须在其中)
|
||||
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
|
||||
if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
|
||||
return false, "access denied"
|
||||
}
|
||||
|
||||
|
||||
@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
|
||||
require.Equal(t, 200, w.Code)
|
||||
require.Equal(t, "9.9.9.9", w.Body.String())
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
|
||||
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
|
||||
blacklist := CompileIPRules([]string{"10.1.1.1"})
|
||||
|
||||
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
|
||||
require.True(t, allowed)
|
||||
require.Equal(t, "", reason)
|
||||
|
||||
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
|
||||
require.False(t, allowed)
|
||||
require.Equal(t, "access denied", reason)
|
||||
}
|
||||
|
||||
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
|
||||
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
|
||||
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
|
||||
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
|
||||
require.False(t, allowed)
|
||||
require.Equal(t, "access denied", reason)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -42,15 +43,19 @@ type LogEvent struct {
|
||||
|
||||
var (
|
||||
mu sync.RWMutex
|
||||
global *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
global atomic.Pointer[zap.Logger]
|
||||
sugar atomic.Pointer[zap.SugaredLogger]
|
||||
atomicLevel zap.AtomicLevel
|
||||
initOptions InitOptions
|
||||
currentSink Sink
|
||||
currentSink atomic.Value // sinkState
|
||||
stdLogUndo func()
|
||||
bootstrapOnce sync.Once
|
||||
)
|
||||
|
||||
type sinkState struct {
|
||||
sink Sink
|
||||
}
|
||||
|
||||
func InitBootstrap() {
|
||||
bootstrapOnce.Do(func() {
|
||||
if err := Init(bootstrapOptions()); err != nil {
|
||||
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
prev := global
|
||||
global = zl
|
||||
sugar = zl.Sugar()
|
||||
prev := global.Load()
|
||||
global.Store(zl)
|
||||
sugar.Store(zl.Sugar())
|
||||
atomicLevel = al
|
||||
initOptions = normalized
|
||||
|
||||
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
|
||||
func CurrentLevel() string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global == nil {
|
||||
if global.Load() == nil {
|
||||
return "info"
|
||||
}
|
||||
return atomicLevel.Level().String()
|
||||
}
|
||||
|
||||
func SetSink(sink Sink) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
currentSink = sink
|
||||
currentSink.Store(sinkState{sink: sink})
|
||||
}
|
||||
|
||||
func loadSink() Sink {
|
||||
v := currentSink.Load()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
state, ok := v.(sinkState)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return state.sink
|
||||
}
|
||||
|
||||
// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。
|
||||
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
|
||||
func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
sink := loadSink()
|
||||
if sink == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
|
||||
}
|
||||
|
||||
func L() *zap.Logger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if global != nil {
|
||||
return global
|
||||
if l := global.Load(); l != nil {
|
||||
return l
|
||||
}
|
||||
return zap.NewNop()
|
||||
}
|
||||
|
||||
func S() *zap.SugaredLogger {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
if sugar != nil {
|
||||
return sugar
|
||||
if s := sugar.Load(); s != nil {
|
||||
return s
|
||||
}
|
||||
return zap.NewNop().Sugar()
|
||||
}
|
||||
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
mu.RLock()
|
||||
l := global
|
||||
mu.RUnlock()
|
||||
l := global.Load()
|
||||
if l != nil {
|
||||
_ = l.Sync()
|
||||
}
|
||||
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
|
||||
|
||||
log.SetFlags(0)
|
||||
log.SetPrefix("")
|
||||
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
|
||||
base := global.Load()
|
||||
if base == nil {
|
||||
base = zap.NewNop()
|
||||
}
|
||||
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
|
||||
|
||||
stdLogUndo = func() {
|
||||
log.SetOutput(prevWriter)
|
||||
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
|
||||
}
|
||||
|
||||
func bridgeSlogLocked() {
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
|
||||
base := global.Load()
|
||||
if base == nil {
|
||||
base = zap.NewNop()
|
||||
}
|
||||
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
|
||||
}
|
||||
|
||||
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
|
||||
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
|
||||
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
|
||||
// Only handle sink forwarding — the inner cores write via their own
|
||||
// Write methods (added to CheckedEntry by s.core.Check above).
|
||||
mu.RLock()
|
||||
sink := currentSink
|
||||
mu.RUnlock()
|
||||
sink := loadSink()
|
||||
if sink == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
|
||||
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
|
||||
return LevelError
|
||||
}
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
|
||||
return LevelWarn
|
||||
}
|
||||
return LevelInfo
|
||||
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
|
||||
return
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
initialized := global != nil
|
||||
mu.RUnlock()
|
||||
initialized := global.Load() != nil
|
||||
if !initialized {
|
||||
// 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。
|
||||
log.Print(msg)
|
||||
|
||||
@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
|
||||
return true
|
||||
})
|
||||
|
||||
entry := h.logger.With(fields...)
|
||||
switch {
|
||||
case record.Level >= slog.LevelError:
|
||||
entry.Error(record.Message)
|
||||
h.logger.Error(record.Message, fields...)
|
||||
case record.Level >= slog.LevelWarn:
|
||||
entry.Warn(record.Message)
|
||||
h.logger.Warn(record.Message, fields...)
|
||||
case record.Level <= slog.LevelDebug:
|
||||
entry.Debug(record.Message)
|
||||
h.logger.Debug(record.Message, fields...)
|
||||
default:
|
||||
entry.Info(record.Message)
|
||||
h.logger.Info(record.Message, fields...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
|
||||
{msg: "Warning: queue full", want: LevelWarn},
|
||||
{msg: "Forward request failed: timeout", want: LevelError},
|
||||
{msg: "[ERROR] upstream unavailable", want: LevelError},
|
||||
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
|
||||
{msg: "service started", want: LevelInfo},
|
||||
{msg: "debug: cache miss", want: LevelDebug},
|
||||
}
|
||||
|
||||
@@ -36,10 +36,18 @@ const (
|
||||
SessionTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
|
||||
OAuthPlatformOpenAI = "openai"
|
||||
// OAuthPlatformSora uses Sora OAuth client.
|
||||
OAuthPlatformSora = "sora"
|
||||
)
|
||||
|
||||
// OAuthSession stores OAuth flow state for OpenAI
|
||||
type OAuthSession struct {
|
||||
State string `json:"state"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
|
||||
|
||||
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
|
||||
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
|
||||
}
|
||||
|
||||
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
|
||||
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
|
||||
if redirectURI == "" {
|
||||
redirectURI = DefaultRedirectURI
|
||||
}
|
||||
|
||||
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("response_type", "code")
|
||||
params.Set("client_id", ClientID)
|
||||
params.Set("client_id", clientID)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("scope", DefaultScopes)
|
||||
params.Set("state", state)
|
||||
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
|
||||
params.Set("code_challenge_method", "S256")
|
||||
// OpenAI specific parameters
|
||||
params.Set("id_token_add_organizations", "true")
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
if codexFlow {
|
||||
params.Set("codex_cli_simplified_flow", "true")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
|
||||
}
|
||||
|
||||
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
|
||||
// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri),
|
||||
// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
|
||||
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(platform)) {
|
||||
case OAuthPlatformSora:
|
||||
return ClientID, false
|
||||
default:
|
||||
return ClientID, true
|
||||
}
|
||||
}
|
||||
|
||||
// TokenRequest represents the token exchange request body
|
||||
type TokenRequest struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
|
||||
return params.Encode()
|
||||
}
|
||||
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims
|
||||
// Note: This does NOT verify the signature - it only decodes the payload
|
||||
// For production, you should verify the token signature using OpenAI's public keys
|
||||
// ParseIDToken parses the ID Token JWT and extracts claims.
|
||||
// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。
|
||||
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
|
||||
//
|
||||
// https://auth.openai.com/.well-known/jwks.json
|
||||
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
|
||||
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
|
||||
}
|
||||
|
||||
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
|
||||
const clockSkewTolerance = 120 // 秒
|
||||
now := time.Now().Unix()
|
||||
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
|
||||
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
|
||||
t.Fatal("stopCh 未关闭")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
|
||||
t.Fatalf("codex flow mismatch: got=%q want=true", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id,
|
||||
// 但不启用 codex_cli_simplified_flow。
|
||||
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
|
||||
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse URL failed: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
if got := q.Get("client_id"); got != ClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
|
||||
}
|
||||
if got := q.Get("codex_cli_simplified_flow"); got != "" {
|
||||
t.Fatalf("codex flow should be empty for sora, got=%q", got)
|
||||
}
|
||||
if got := q.Get("id_token_add_organizations"); got != "true" {
|
||||
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
|
||||
t.Helper()
|
||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
||||
var raw struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
||||
|
||||
|
||||
@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
"cipher_suites", len(spec.CipherSuites),
|
||||
"extensions", len(spec.Extensions),
|
||||
"compression_methods", spec.CompressionMethods,
|
||||
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
|
||||
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
|
||||
"tls_vers_max", spec.TLSVersMax,
|
||||
"tls_vers_min", spec.TLSVersMin)
|
||||
|
||||
if d.profile != nil {
|
||||
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
|
||||
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_socks5_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
|
||||
// Log successful handshake details
|
||||
state := tlsConn.ConnectionState()
|
||||
slog.Debug("tls_fingerprint_handshake_success",
|
||||
"version", fmt.Sprintf("0x%04x", state.Version),
|
||||
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
|
||||
"version", state.Version,
|
||||
"cipher_suite", state.CipherSuite,
|
||||
"alpn", state.NegotiatedProtocol)
|
||||
|
||||
return tlsConn, nil
|
||||
|
||||
@@ -139,6 +139,7 @@ type UsageLogFilters struct {
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Model string
|
||||
RequestType *int16
|
||||
Stream *bool
|
||||
BillingType *int8
|
||||
StartTime *time.Time
|
||||
|
||||
Reference in New Issue
Block a user