feat: 新增支持codex转发

This commit is contained in:
shaw
2025-12-22 22:58:31 +08:00
parent dacf3a2a6e
commit 6c469b42ed
46 changed files with 3749 additions and 477 deletions

View File

@@ -24,11 +24,6 @@ import (
"github.com/gin-gonic/gin"
)
// ClaudeUpstream handles HTTP requests to Claude API
type ClaudeUpstream interface {
Do(req *http.Request, proxyURL string) (*http.Response, error)
}
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
@@ -87,7 +82,7 @@ type GatewayService struct {
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
claudeUpstream ClaudeUpstream
httpUpstream ports.HTTPUpstream
}
// NewGatewayService creates a new GatewayService
@@ -102,7 +97,7 @@ func NewGatewayService(
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
claudeUpstream ClaudeUpstream,
httpUpstream ports.HTTPUpstream,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -115,7 +110,7 @@ func NewGatewayService(
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
claudeUpstream: claudeUpstream,
httpUpstream: httpUpstream,
}
}
@@ -285,13 +280,13 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
}
}
// 2. 获取可调度账号列表(排除限流和过载的账号)
// 2. 获取可调度账号列表(排除限流和过载的账号,仅限 Anthropic 平台
var accounts []model.Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, model.PlatformAnthropic)
} else {
accounts, err = s.accountRepo.ListSchedulable(ctx)
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, model.PlatformAnthropic)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -407,7 +402,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *m
}
// 发送请求
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
}
@@ -481,7 +476,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
@@ -502,8 +497,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
// 确保必要的headers存在
if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json")
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
@@ -982,7 +977,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 发送请求
resp, err := s.claudeUpstream.Do(upstreamReq, proxyURL)
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
@@ -1049,7 +1044,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
@@ -1073,8 +1068,8 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// 确保必要的 headers 存在
if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json")
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")