feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -152,6 +152,7 @@ var claudeModels = []modelDef{
{ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"},
{ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"},
{ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"},
}
@@ -165,6 +166,8 @@ var geminiModels = []modelDef{
{ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"},
{ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"},
{ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"},
}

View File

@@ -0,0 +1,26 @@
package antigravity
import "testing"
func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) {
t.Parallel()
models := DefaultModels()
byID := make(map[string]ClaudeModel, len(models))
for _, m := range models {
byID[m.ID] = m
}
requiredIDs := []string{
"claude-opus-4-6-thinking",
"gemini-3.1-flash-image",
"gemini-3.1-flash-image-preview",
"gemini-3-pro-image", // legacy compatibility
}
for _, id := range requiredIDs {
if _, ok := byID[id]; !ok {
t.Fatalf("expected model %q to be exposed in DefaultModels", id)
}
}
}

View File

@@ -70,7 +70,7 @@ type GeminiGenerationConfig struct {
ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"`
}
// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image 支持)
// GeminiImageConfig Gemini 图片生成配置gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持)
type GeminiImageConfig struct {
AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4"
ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K"

View File

@@ -53,7 +53,8 @@ const (
var defaultUserAgentVersion = "1.19.6"
// defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置
var defaultClientSecret = "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"
// 默认值使用占位符,生产环境请通过环境变量注入真实值。
var defaultClientSecret = "GOCSPX-your-client-secret"
func init() {
// 从环境变量读取版本号,未设置则使用默认值

View File

@@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) {
expectedParams := map[string]string{
"client_id": ClientID,
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"redirect_uri": RedirectURI,
"response_type": "code",
"scope": Scopes,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
"include_granted_scopes": "true",
}
@@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) {
if err != nil {
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
}
if secret != "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" {
if secret != "GOCSPX-your-client-secret" {
t.Errorf("默认 client_secret 不匹配: got %s", secret)
}
if RedirectURI != "http://localhost:8085/callback" {

View File

@@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) {
})
}
}
func TestToHTTP_MetadataDeepCopy(t *testing.T) {
md := map[string]string{"k": "v"}
appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md)
code, body := ToHTTP(appErr)
require.Equal(t, http.StatusBadRequest, code)
require.Equal(t, "v", body.Metadata["k"])
md["k"] = "changed"
require.Equal(t, "v", body.Metadata["k"])
appErr.Metadata["k"] = "changed-again"
require.Equal(t, "v", body.Metadata["k"])
}

View File

@@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) {
return http.StatusOK, Status{Code: int32(http.StatusOK)}
}
cloned := Clone(appErr)
return int(cloned.Code), cloned.Status
body = Status{
Code: appErr.Code,
Reason: appErr.Reason,
Message: appErr.Message,
}
if appErr.Metadata != nil {
body.Metadata = make(map[string]string, len(appErr.Metadata))
for k, v := range appErr.Metadata {
body.Metadata[k] = v
}
}
return int(appErr.Code), body
}

View File

@@ -39,7 +39,7 @@ const (
// They enable the "login without creating your own OAuth client" experience, but Google may
// restrict which scopes are allowed for this client.
GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
GeminiCLIOAuthClientSecret = "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"
GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret"
// GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret.
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET"

View File

@@ -32,6 +32,7 @@ const (
defaultMaxIdleConns = 100 // 最大空闲连接数
defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数
defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时)
validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL
)
// Options 定义共享 HTTP 客户端的构建参数
@@ -53,6 +54,9 @@ type Options struct {
// sharedClients 存储按配置参数缓存的 http.Client 实例
var sharedClients sync.Map
// 允许测试替换校验函数,生产默认指向真实实现。
var validateResolvedIP = urlvalidator.ValidateResolvedIP
// GetClient 返回共享的 HTTP 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建 Transport
// 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险
@@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) {
var rt http.RoundTripper = transport
if opts.ValidateResolvedIP && !opts.AllowPrivateHosts {
rt = &validatedTransport{base: transport}
rt = newValidatedTransport(transport)
}
return &http.Client{
Transport: rt,
@@ -149,17 +153,56 @@ func buildClientKey(opts Options) string {
}
type validatedTransport struct {
base http.RoundTripper
base http.RoundTripper
validatedHosts sync.Map // map[string]time.Time, value 为过期时间
now func() time.Time
}
func newValidatedTransport(base http.RoundTripper) *validatedTransport {
return &validatedTransport{
base: base,
now: time.Now,
}
}
func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool {
if t == nil {
return false
}
raw, ok := t.validatedHosts.Load(host)
if !ok {
return false
}
expireAt, ok := raw.(time.Time)
if !ok {
t.validatedHosts.Delete(host)
return false
}
if now.Before(expireAt) {
return true
}
t.validatedHosts.Delete(host)
return false
}
func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req != nil && req.URL != nil {
host := strings.TrimSpace(req.URL.Hostname())
host := strings.ToLower(strings.TrimSpace(req.URL.Hostname()))
if host != "" {
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return nil, err
now := time.Now()
if t != nil && t.now != nil {
now = t.now()
}
if !t.isValidatedHost(host, now) {
if err := validateResolvedIP(host); err != nil {
return nil, err
}
t.validatedHosts.Store(host, now.Add(validatedHostTTL))
}
}
}
if t == nil || t.base == nil {
return nil, fmt.Errorf("validated transport base is nil")
}
return t.base.RoundTrip(req)
}

View File

@@ -0,0 +1,115 @@
package httpclient
import (
"errors"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestValidatedTransport_CacheHostValidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(host string) error {
atomic.AddInt32(&validateCalls, 1)
require.Equal(t, "api.openai.com", host)
return nil
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730000000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls))
require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls))
}
func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
var validateCalls int32
validateResolvedIP = func(_ string) error {
atomic.AddInt32(&validateCalls, 1)
return nil
}
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{}`)),
Header: make(http.Header),
}, nil
})
now := time.Unix(1730001000, 0)
transport := newValidatedTransport(base)
transport.now = func() time.Time { return now }
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
now = now.Add(validatedHostTTL + time.Second)
_, err = transport.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls))
}
func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) {
originalValidate := validateResolvedIP
defer func() { validateResolvedIP = originalValidate }()
expectedErr := errors.New("dns rebinding rejected")
validateResolvedIP = func(_ string) error {
return expectedErr
}
var baseCalls int32
base := roundTripFunc(func(_ *http.Request) (*http.Response, error) {
atomic.AddInt32(&baseCalls, 1)
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil
})
transport := newValidatedTransport(base)
req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil)
require.NoError(t, err)
_, err = transport.RoundTrip(req)
require.ErrorIs(t, err, expectedErr)
require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls))
}

View File

@@ -0,0 +1,37 @@
package httputil
import (
"bytes"
"io"
"net/http"
)
const (
requestBodyReadInitCap = 512
requestBodyReadMaxInitCap = 1 << 20
)
// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length.
func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) {
if req == nil || req.Body == nil {
return nil, nil
}
capHint := requestBodyReadInitCap
if req.ContentLength > 0 {
switch {
case req.ContentLength < int64(requestBodyReadInitCap):
capHint = requestBodyReadInitCap
case req.ContentLength > int64(requestBodyReadMaxInitCap):
capHint = requestBodyReadMaxInitCap
default:
capHint = int(req.ContentLength)
}
}
buf := bytes.NewBuffer(make([]byte, 0, capHint))
if _, err := io.Copy(buf, req.Body); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@@ -67,6 +67,14 @@ func normalizeIP(ip string) string {
// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析
var privateNets []*net.IPNet
// CompiledIPRules 表示预编译的 IP 匹配规则。
// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。
type CompiledIPRules struct {
CIDRs []*net.IPNet
IPs []net.IP
PatternCount int
}
func init() {
for _, cidr := range []string{
"10.0.0.0/8",
@@ -84,6 +92,53 @@ func init() {
}
}
// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。
// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。
func CompileIPRules(patterns []string) *CompiledIPRules {
compiled := &CompiledIPRules{
CIDRs: make([]*net.IPNet, 0, len(patterns)),
IPs: make([]net.IP, 0, len(patterns)),
PatternCount: len(patterns),
}
for _, pattern := range patterns {
normalized := strings.TrimSpace(pattern)
if normalized == "" {
continue
}
if strings.Contains(normalized, "/") {
_, cidr, err := net.ParseCIDR(normalized)
if err != nil || cidr == nil {
continue
}
compiled.CIDRs = append(compiled.CIDRs, cidr)
continue
}
parsedIP := net.ParseIP(normalized)
if parsedIP == nil {
continue
}
compiled.IPs = append(compiled.IPs, parsedIP)
}
return compiled
}
func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool {
if parsedIP == nil || rules == nil {
return false
}
for _, cidr := range rules.CIDRs {
if cidr.Contains(parsedIP) {
return true
}
}
for _, ruleIP := range rules.IPs {
if parsedIP.Equal(ruleIP) {
return true
}
}
return false
}
// isPrivateIP 检查 IP 是否为私有地址。
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
@@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool {
// 2. 如果白名单不为空IP 必须在白名单中
// 3. 如果白名单为空,允许访问(除非被黑名单拒绝)
func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) {
return CheckIPRestrictionWithCompiledRules(
clientIP,
CompileIPRules(whitelist),
CompileIPRules(blacklist),
)
}
// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。
func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) {
// 规范化 IP
clientIP = normalizeIP(clientIP)
if clientIP == "" {
return false, "access denied"
}
parsedIP := net.ParseIP(clientIP)
if parsedIP == nil {
return false, "access denied"
}
// 1. 检查黑名单
if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) {
if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) {
return false, "access denied"
}
// 2. 检查白名单如果设置了白名单IP 必须在其中)
if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) {
if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) {
return false, "access denied"
}

View File

@@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) {
require.Equal(t, 200, w.Code)
require.Equal(t, "9.9.9.9", w.Body.String())
}
func TestCheckIPRestrictionWithCompiledRules(t *testing.T) {
whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"})
blacklist := CompileIPRules([]string{"10.1.1.1"})
allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist)
require.True(t, allowed)
require.Equal(t, "", reason)
allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}
func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) {
// 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。
invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"})
allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil)
require.False(t, allowed)
require.Equal(t, "access denied", reason)
}

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"go.uber.org/zap"
@@ -42,15 +43,19 @@ type LogEvent struct {
var (
mu sync.RWMutex
global *zap.Logger
sugar *zap.SugaredLogger
global atomic.Pointer[zap.Logger]
sugar atomic.Pointer[zap.SugaredLogger]
atomicLevel zap.AtomicLevel
initOptions InitOptions
currentSink Sink
currentSink atomic.Value // sinkState
stdLogUndo func()
bootstrapOnce sync.Once
)
type sinkState struct {
sink Sink
}
func InitBootstrap() {
bootstrapOnce.Do(func() {
if err := Init(bootstrapOptions()); err != nil {
@@ -72,9 +77,9 @@ func initLocked(options InitOptions) error {
return err
}
prev := global
global = zl
sugar = zl.Sugar()
prev := global.Load()
global.Store(zl)
sugar.Store(zl.Sugar())
atomicLevel = al
initOptions = normalized
@@ -115,24 +120,32 @@ func SetLevel(level string) error {
func CurrentLevel() string {
mu.RLock()
defer mu.RUnlock()
if global == nil {
if global.Load() == nil {
return "info"
}
return atomicLevel.Level().String()
}
func SetSink(sink Sink) {
mu.Lock()
defer mu.Unlock()
currentSink = sink
currentSink.Store(sinkState{sink: sink})
}
func loadSink() Sink {
v := currentSink.Load()
if v == nil {
return nil
}
state, ok := v.(sinkState)
if !ok {
return nil
}
return state.sink
}
// WriteSinkEvent 直接写入日志 sink不经过全局日志级别门控。
// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。
func WriteSinkEvent(level, component, message string, fields map[string]any) {
mu.RLock()
sink := currentSink
mu.RUnlock()
sink := loadSink()
if sink == nil {
return
}
@@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) {
}
func L() *zap.Logger {
mu.RLock()
defer mu.RUnlock()
if global != nil {
return global
if l := global.Load(); l != nil {
return l
}
return zap.NewNop()
}
func S() *zap.SugaredLogger {
mu.RLock()
defer mu.RUnlock()
if sugar != nil {
return sugar
if s := sugar.Load(); s != nil {
return s
}
return zap.NewNop().Sugar()
}
@@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger {
}
func Sync() {
mu.RLock()
l := global
mu.RUnlock()
l := global.Load()
if l != nil {
_ = l.Sync()
}
@@ -210,7 +217,11 @@ func bridgeStdLogLocked() {
log.SetFlags(0)
log.SetPrefix("")
log.SetOutput(newStdLogBridge(global.Named("stdlog")))
base := global.Load()
if base == nil {
base = zap.NewNop()
}
log.SetOutput(newStdLogBridge(base.Named("stdlog")))
stdLogUndo = func() {
log.SetOutput(prevWriter)
@@ -220,7 +231,11 @@ func bridgeStdLogLocked() {
}
func bridgeSlogLocked() {
slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog"))))
base := global.Load()
if base == nil {
base = zap.NewNop()
}
slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog"))))
}
func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) {
@@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore
func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
// Only handle sink forwarding — the inner cores write via their own
// Write methods (added to CheckedEntry by s.core.Check above).
mu.RLock()
sink := currentSink
mu.RUnlock()
sink := loadSink()
if sink == nil {
return nil
}
@@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level {
if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") {
return LevelError
}
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") {
return LevelWarn
}
return LevelInfo
@@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) {
return
}
mu.RLock()
initialized := global != nil
mu.RUnlock()
initialized := global.Load() != nil
if !initialized {
// 在日志系统未初始化前,回退到标准库 log避免测试/工具链丢日志。
log.Print(msg)

View File

@@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error {
return true
})
entry := h.logger.With(fields...)
switch {
case record.Level >= slog.LevelError:
entry.Error(record.Message)
h.logger.Error(record.Message, fields...)
case record.Level >= slog.LevelWarn:
entry.Warn(record.Message)
h.logger.Warn(record.Message, fields...)
case record.Level <= slog.LevelDebug:
entry.Debug(record.Message)
h.logger.Debug(record.Message, fields...)
default:
entry.Info(record.Message)
h.logger.Info(record.Message, fields...)
}
return nil
}

View File

@@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) {
{msg: "Warning: queue full", want: LevelWarn},
{msg: "Forward request failed: timeout", want: LevelError},
{msg: "[ERROR] upstream unavailable", want: LevelError},
{msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo},
{msg: "service started", want: LevelInfo},
{msg: "debug: cache miss", want: LevelDebug},
}

View File

@@ -36,10 +36,18 @@ const (
SessionTTL = 30 * time.Minute
)
const (
// OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client.
OAuthPlatformOpenAI = "openai"
// OAuthPlatformSora uses Sora OAuth client.
OAuthPlatformSora = "sora"
)
// OAuthSession stores OAuth flow state for OpenAI
type OAuthSession struct {
State string `json:"state"`
CodeVerifier string `json:"code_verifier"`
ClientID string `json:"client_id,omitempty"`
ProxyURL string `json:"proxy_url,omitempty"`
RedirectURI string `json:"redirect_uri"`
CreatedAt time.Time `json:"created_at"`
@@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string {
// BuildAuthorizationURL builds the OpenAI OAuth authorization URL
func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI)
}
// BuildAuthorizationURLForPlatform builds authorization URL by platform.
func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string {
if redirectURI == "" {
redirectURI = DefaultRedirectURI
}
clientID, codexFlow := OAuthClientConfigByPlatform(platform)
params := url.Values{}
params.Set("response_type", "code")
params.Set("client_id", ClientID)
params.Set("client_id", clientID)
params.Set("redirect_uri", redirectURI)
params.Set("scope", DefaultScopes)
params.Set("state", state)
@@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string {
params.Set("code_challenge_method", "S256")
// OpenAI specific parameters
params.Set("id_token_add_organizations", "true")
params.Set("codex_cli_simplified_flow", "true")
if codexFlow {
params.Set("codex_cli_simplified_flow", "true")
}
return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode())
}
// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled.
// Sora 授权流程复用 Codex CLI 的 client_id支持 localhost redirect_uri
// 但不启用 codex_cli_simplified_flow拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。
func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) {
switch strings.ToLower(strings.TrimSpace(platform)) {
case OAuthPlatformSora:
return ClientID, false
default:
return ClientID, true
}
}
// TokenRequest represents the token exchange request body
type TokenRequest struct {
GrantType string `json:"grant_type"`
@@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string {
return params.Encode()
}
// ParseIDToken parses the ID Token JWT and extracts claims
// Note: This does NOT verify the signature - it only decodes the payload
// For production, you should verify the token signature using OpenAI's public keys
// ParseIDToken parses the ID Token JWT and extracts claims.
// 注意:当前仅解码 payload 并校验 exp未验证 JWT 签名。
// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名:
//
// https://auth.openai.com/.well-known/jwks.json
func ParseIDToken(idToken string) (*IDTokenClaims, error) {
parts := strings.Split(idToken, ".")
if len(parts) != 3 {
@@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) {
return nil, fmt.Errorf("failed to parse JWT claims: %w", err)
}
// 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌)
const clockSkewTolerance = 120 // 秒
now := time.Now().Unix()
if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance {
return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance)
}
return &claims, nil
}

View File

@@ -1,6 +1,7 @@
package openai
import (
"net/url"
"sync"
"testing"
"time"
@@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) {
t.Fatal("stopCh 未关闭")
}
}
func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "true" {
t.Fatalf("codex flow mismatch: got=%q want=true", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}
// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id
// 但不启用 codex_cli_simplified_flow。
func TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) {
authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("Parse URL failed: %v", err)
}
q := parsed.Query()
if got := q.Get("client_id"); got != ClientID {
t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID)
}
if got := q.Get("codex_cli_simplified_flow"); got != "" {
t.Fatalf("codex flow should be empty for sora, got=%q", got)
}
if got := q.Get("id_token_add_organizations"); got != "true" {
t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got)
}
}

View File

@@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P
t.Helper()
// 先用 raw json 解析,因为 Data 是 any 类型
var raw struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))

View File

@@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
"cipher_suites", len(spec.CipherSuites),
"extensions", len(spec.Extensions),
"compression_methods", spec.CompressionMethods,
"tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax),
"tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin))
"tls_vers_max", spec.TLSVersMax,
"tls_vers_min", spec.TLSVersMin)
if d.profile != nil {
slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE)
@@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_socks5_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
@@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_http_proxy_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil
@@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.
// Log successful handshake details
state := tlsConn.ConnectionState()
slog.Debug("tls_fingerprint_handshake_success",
"version", fmt.Sprintf("0x%04x", state.Version),
"cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite),
"version", state.Version,
"cipher_suite", state.CipherSuite,
"alpn", state.NegotiatedProtocol)
return tlsConn, nil

View File

@@ -139,6 +139,7 @@ type UsageLogFilters struct {
AccountID int64
GroupID int64
Model string
RequestType *int16
Stream *bool
BillingType *int8
StartTime *time.Time