feat(openai): 增加 OAuth 账号 Codex 官方客户端限制开关
新增 codex_cli_only 开关并默认关闭,关闭时完全绕过限制逻辑。 在 OpenAI 网关引入统一检测入口,集中判定账号类型、开关与客户端族。 开启后仅放行 codex_cli_rs、codex_vscode、codex_app 客户端家族。 补充后端判定与网关分支测试,并在前端创建/编辑页增加开关配置与回显。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -9,6 +9,14 @@ var CodexCLIUserAgentPrefixes = []string{
|
||||
"codex_cli_rs/",
|
||||
}
|
||||
|
||||
// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。
|
||||
// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。
|
||||
var CodexOfficialClientUserAgentPrefixes = []string{
|
||||
"codex_cli_rs/",
|
||||
"codex_vscode/",
|
||||
"codex_app/",
|
||||
}
|
||||
|
||||
// IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request
|
||||
func IsCodexCLIRequest(userAgent string) bool {
|
||||
ua := strings.ToLower(strings.TrimSpace(userAgent))
|
||||
@@ -27,3 +35,23 @@ func IsCodexCLIRequest(userAgent string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。
|
||||
// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。
|
||||
func IsCodexOfficialClientRequest(userAgent string) bool {
|
||||
ua := strings.ToLower(strings.TrimSpace(userAgent))
|
||||
if ua == "" {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range CodexOfficialClientUserAgentPrefixes {
|
||||
normalizedPrefix := strings.ToLower(strings.TrimSpace(prefix))
|
||||
if normalizedPrefix == "" {
|
||||
continue
|
||||
}
|
||||
// 优先前缀匹配;若 UA 被网关/代理拼接为复合字符串时,退化为包含匹配。
|
||||
if strings.HasPrefix(ua, normalizedPrefix) || strings.Contains(ua, normalizedPrefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -26,3 +26,28 @@ func TestIsCodexCLIRequest(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCodexOfficialClientRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true},
|
||||
{name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true},
|
||||
{name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true},
|
||||
{name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true},
|
||||
{name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true},
|
||||
{name: "非 codex", ua: "curl/8.0.1", want: false},
|
||||
{name: "空字符串", ua: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsCodexOfficialClientRequest(tt.ua)
|
||||
if got != tt.want {
|
||||
t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -719,6 +719,17 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||
}
|
||||
|
||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
|
||||
// 字段:accounts.extra.codex_cli_only。
|
||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
||||
if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := a.Extra["codex_cli_only"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// WindowCostSchedulability 窗口费用调度状态
|
||||
type WindowCostSchedulability int
|
||||
|
||||
|
||||
@@ -70,3 +70,67 @@ func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) {
|
||||
require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
|
||||
t.Run("OpenAI OAuth 开启", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("OpenAI OAuth 关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": false,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("字段缺失默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.False(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("类型非法默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": "true",
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("非 OAuth 账号始终关闭", func(t *testing.T) {
|
||||
apiKeyAccount := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": true,
|
||||
},
|
||||
}
|
||||
require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled())
|
||||
|
||||
otherPlatform := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": true,
|
||||
},
|
||||
}
|
||||
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
// CodexClientRestrictionReasonDisabled 表示账号未开启 codex_cli_only。
|
||||
CodexClientRestrictionReasonDisabled = "codex_cli_only_disabled"
|
||||
// CodexClientRestrictionReasonMatchedUA 表示请求命中官方客户端 UA 白名单。
|
||||
CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched"
|
||||
// CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。
|
||||
CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched"
|
||||
// CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。
|
||||
CodexClientRestrictionReasonForceCodexCLI = "force_codex_cli_enabled"
|
||||
)
|
||||
|
||||
// CodexClientRestrictionDetectionResult 是 codex_cli_only 统一检测入口结果。
|
||||
type CodexClientRestrictionDetectionResult struct {
|
||||
Enabled bool
|
||||
Matched bool
|
||||
Reason string
|
||||
}
|
||||
|
||||
// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。
|
||||
type CodexClientRestrictionDetector interface {
|
||||
Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult
|
||||
}
|
||||
|
||||
// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。
|
||||
type OpenAICodexClientRestrictionDetector struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexClientRestrictionDetector {
|
||||
return &OpenAICodexClientRestrictionDetector{cfg: cfg}
|
||||
}
|
||||
|
||||
func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
|
||||
if account == nil || !account.IsCodexCLIOnlyEnabled() {
|
||||
return CodexClientRestrictionDetectionResult{
|
||||
Enabled: false,
|
||||
Matched: false,
|
||||
Reason: CodexClientRestrictionReasonDisabled,
|
||||
}
|
||||
}
|
||||
|
||||
if d != nil && d.cfg != nil && d.cfg.Gateway.ForceCodexCLI {
|
||||
return CodexClientRestrictionDetectionResult{
|
||||
Enabled: true,
|
||||
Matched: true,
|
||||
Reason: CodexClientRestrictionReasonForceCodexCLI,
|
||||
}
|
||||
}
|
||||
|
||||
userAgent := ""
|
||||
if c != nil {
|
||||
userAgent = c.GetHeader("User-Agent")
|
||||
}
|
||||
if openai.IsCodexOfficialClientRequest(userAgent) {
|
||||
return CodexClientRestrictionDetectionResult{
|
||||
Enabled: true,
|
||||
Matched: true,
|
||||
Reason: CodexClientRestrictionReasonMatchedUA,
|
||||
}
|
||||
}
|
||||
|
||||
return CodexClientRestrictionDetectionResult{
|
||||
Enabled: true,
|
||||
Matched: false,
|
||||
Reason: CodexClientRestrictionReasonNotMatchedUA,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newCodexDetectorTestContext(ua string) *gin.Context {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
if ua != "" {
|
||||
c.Request.Header.Set("User-Agent", ua)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("未开启开关时绕过", func(t *testing.T) {
|
||||
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
|
||||
|
||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0"), account)
|
||||
require.False(t, result.Enabled)
|
||||
require.False(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason)
|
||||
})
|
||||
|
||||
t.Run("开启后 codex_cli_rs 命中", func(t *testing.T) {
|
||||
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"codex_cli_only": true},
|
||||
}
|
||||
|
||||
result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0"), account)
|
||||
require.True(t, result.Enabled)
|
||||
require.True(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
||||
})
|
||||
|
||||
t.Run("开启后 codex_vscode 命中", func(t *testing.T) {
|
||||
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"codex_cli_only": true},
|
||||
}
|
||||
|
||||
result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0"), account)
|
||||
require.True(t, result.Enabled)
|
||||
require.True(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
||||
})
|
||||
|
||||
t.Run("开启后 codex_app 命中", func(t *testing.T) {
|
||||
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"codex_cli_only": true},
|
||||
}
|
||||
|
||||
result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0"), account)
|
||||
require.True(t, result.Enabled)
|
||||
require.True(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason)
|
||||
})
|
||||
|
||||
t.Run("开启后非官方客户端拒绝", func(t *testing.T) {
|
||||
detector := NewOpenAICodexClientRestrictionDetector(nil)
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"codex_cli_only": true},
|
||||
}
|
||||
|
||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0"), account)
|
||||
require.True(t, result.Enabled)
|
||||
require.False(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason)
|
||||
})
|
||||
|
||||
t.Run("开启 ForceCodexCLI 时允许通过", func(t *testing.T) {
|
||||
detector := NewOpenAICodexClientRestrictionDetector(&config.Config{
|
||||
Gateway: config.GatewayConfig{ForceCodexCLI: true},
|
||||
})
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{"codex_cli_only": true},
|
||||
}
|
||||
|
||||
result := detector.Detect(newCodexDetectorTestContext("curl/8.0"), account)
|
||||
require.True(t, result.Enabled)
|
||||
require.True(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
|
||||
})
|
||||
}
|
||||
@@ -190,6 +190,7 @@ type OpenAIGatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
codexDetector CodexClientRestrictionDetector
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
@@ -225,6 +226,7 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
codexDetector: NewOpenAICodexClientRestrictionDetector(cfg),
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
@@ -237,6 +239,65 @@ func NewOpenAIGatewayService(
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector {
|
||||
if s != nil && s.codexDetector != nil {
|
||||
return s.codexDetector
|
||||
}
|
||||
var cfg *config.Config
|
||||
if s != nil {
|
||||
cfg = s.cfg
|
||||
}
|
||||
return NewOpenAICodexClientRestrictionDetector(cfg)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult {
|
||||
return s.getCodexClientRestrictionDetector().Detect(c, account)
|
||||
}
|
||||
|
||||
func getAPIKeyIDFromContext(c *gin.Context) int64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
v, exists := c.Get("api_key")
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
apiKey, ok := v.(*APIKey)
|
||||
if !ok || apiKey == nil {
|
||||
return 0
|
||||
}
|
||||
return apiKey.ID
|
||||
}
|
||||
|
||||
func logCodexCLIOnlyDetection(ctx context.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult) {
|
||||
if !result.Enabled {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
accountID := int64(0)
|
||||
if account != nil {
|
||||
accountID = account.ID
|
||||
}
|
||||
fields := []zap.Field{
|
||||
zap.String("component", "service.openai_gateway"),
|
||||
zap.Int64("account_id", accountID),
|
||||
zap.Bool("codex_cli_only_enabled", result.Enabled),
|
||||
zap.Bool("codex_official_client_match", result.Matched),
|
||||
zap.String("reject_reason", result.Reason),
|
||||
}
|
||||
if apiKeyID > 0 {
|
||||
fields = append(fields, zap.Int64("api_key_id", apiKeyID))
|
||||
}
|
||||
log := logger.FromContext(ctx).With(fields...)
|
||||
if result.Matched {
|
||||
log.Info("OpenAI codex_cli_only 检测通过")
|
||||
return
|
||||
}
|
||||
log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求")
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
||||
//
|
||||
// Priority:
|
||||
@@ -757,6 +818,19 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re
|
||||
func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
restrictionResult := s.detectCodexClientRestriction(c, account)
|
||||
apiKeyID := getAPIKeyIDFromContext(c)
|
||||
logCodexCLIOnlyDetection(ctx, account, apiKeyID, restrictionResult)
|
||||
if restrictionResult.Enabled && !restrictionResult.Matched {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "forbidden_error",
|
||||
"message": "This account only allows Codex official clients",
|
||||
},
|
||||
})
|
||||
return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
|
||||
}
|
||||
|
||||
originalBody := body
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubCodexRestrictionDetector struct {
|
||||
result CodexClientRestrictionDetectionResult
|
||||
}
|
||||
|
||||
func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult {
|
||||
return s.result
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("使用注入的 detector", func(t *testing.T) {
|
||||
expected := &stubCodexRestrictionDetector{
|
||||
result: CodexClientRestrictionDetectionResult{Enabled: true, Matched: true, Reason: "stub"},
|
||||
}
|
||||
svc := &OpenAIGatewayService{codexDetector: expected}
|
||||
|
||||
got := svc.getCodexClientRestrictionDetector()
|
||||
require.Same(t, expected, got)
|
||||
})
|
||||
|
||||
t.Run("service 为 nil 时返回默认 detector", func(t *testing.T) {
|
||||
var svc *OpenAIGatewayService
|
||||
got := svc.getCodexClientRestrictionDetector()
|
||||
require.NotNil(t, got)
|
||||
})
|
||||
|
||||
t.Run("service 未注入 detector 时返回默认 detector", func(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}}
|
||||
got := svc.getCodexClientRestrictionDetector()
|
||||
require.NotNil(t, got)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}}
|
||||
|
||||
result := got.Detect(c, account)
|
||||
require.True(t, result.Enabled)
|
||||
require.True(t, result.Matched)
|
||||
require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAPIKeyIDFromContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("context 为 nil", func(t *testing.T) {
|
||||
require.Equal(t, int64(0), getAPIKeyIDFromContext(nil))
|
||||
})
|
||||
|
||||
t.Run("上下文没有 api_key", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
require.Equal(t, int64(0), getAPIKeyIDFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("api_key 类型错误", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("api_key", "not-api-key")
|
||||
require.Equal(t, int64(0), getAPIKeyIDFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("api_key 指针为空", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
var k *APIKey
|
||||
c.Set("api_key", k)
|
||||
require.Equal(t, int64(0), getAPIKeyIDFromContext(c))
|
||||
})
|
||||
|
||||
t.Run("正常读取 api_key_id", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Set("api_key", &APIKey{ID: 12345})
|
||||
require.Equal(t, int64(12345), getAPIKeyIDFromContext(c))
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) {
|
||||
// 不校验日志内容,仅保证在 nil 入参下不会 panic。
|
||||
require.NotPanics(t, func() {
|
||||
logCodexCLIOnlyDetection(nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: true, Matched: false, Reason: "test"})
|
||||
logCodexCLIOnlyDetection(context.Background(), nil, 0, CodexClientRestrictionDetectionResult{Enabled: false, Matched: false, Reason: "disabled"})
|
||||
})
|
||||
}
|
||||
@@ -435,6 +435,92 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te
|
||||
require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", "curl/8.0")
|
||||
|
||||
inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Codex official clients")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_CodexCLIOnly_AllowOfficialClientFamilies(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
}{
|
||||
{name: "codex_cli_rs", ua: "codex_cli_rs/0.99.0"},
|
||||
{name: "codex_vscode", ua: "codex_vscode/1.0.0"},
|
||||
{name: "codex_app", ua: "codex_app/2.1.0"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||
c.Request.Header.Set("User-Agent", tt.ua)
|
||||
|
||||
inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}},
|
||||
Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")),
|
||||
}
|
||||
upstream := &httpUpstreamRecorder{resp: resp}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 123,
|
||||
Name: "acc",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
||||
Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
RateMultiplier: f64p(1),
|
||||
}
|
||||
|
||||
_, err := svc.Forward(context.Background(), c, account, inputBody)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user